In [1]:
from pyspark.sql import SparkSession

Initialize a Spark session

In [2]:
spark = SparkSession.builder.master("local[*]").appName("sparkify").getOrCreate()
spark

Read the data

In [3]:
data = spark.read.json("../data/mini_sparkify_event_data.json")

Look at the columns

In [4]:
data.printSchema()

root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)



Look at the first two records, displayed vertically so that every column is visible

In [5]:
data.show(vertical=True, n=2)

-RECORD 0-----------------------------
 artist        | Martha Tilston       
 auth          | Logged In            
 firstName     | Colin                
 gender        | M                    
 itemInSession | 50                   
 lastName      | Freeman              
 length        | 277.89016            
 level         | paid                 
 location      | Bakersfield, CA      
 method        | PUT                  
 page          | NextSong             
 registration  | 1538173362000        
 sessionId     | 29                   
 song          | Rockpools            
 status        | 200                  
 ts            | 1538352117000        
 userAgent     | Mozilla/5.0 (Wind... 
 userId        | 30                   
-RECORD 1-----------------------------
 artist        | Five Iron Frenzy     
 auth          | Logged In            
 firstName     | Micah                
 gender        | M                    
 itemInSession | 79                   
 lastName      | Long    

The most important column in the data set seems to be the `page` column. It records every Sparkify page that a customer visited.

In [6]:
data.select("page").dropDuplicates().show()

+--------------------+
|                page|
+--------------------+
|              Cancel|
|    Submit Downgrade|
|         Thumbs Down|
|                Home|
|           Downgrade|
|         Roll Advert|
|              Logout|
|       Save Settings|
|Cancellation Conf...|
|               About|
| Submit Registration|
|            Settings|
|               Login|
|            Register|
|     Add to Playlist|
|          Add Friend|
|            NextSong|
|           Thumbs Up|
|                Help|
|             Upgrade|
+--------------------+
only showing top 20 rows



From the list above, we can use the `Cancellation Confirmation` page to indicate whether a particular user churned. We are building a model that predicts whether a user is going to churn, given their history of interactions with the service.

Therefore, it is important to understand the customers who reached the `Cancellation Confirmation` page. These are the customers who churned, and the task is to build a prediction model that can recognize them.

#### Data cleaning

Now that we have identified our goal, we can drop some of the columns that are not needed for further analysis.

In [7]:
# Drop identifying columns that are not used as features
data = data.drop('artist', 'firstName', 'lastName')
data.show(vertical=True, n=2)

-RECORD 0-----------------------------
 auth          | Logged In            
 gender        | M                    
 itemInSession | 50                   
 length        | 277.89016            
 level         | paid                 
 location      | Bakersfield, CA      
 method        | PUT                  
 page          | NextSong             
 registration  | 1538173362000        
 sessionId     | 29                   
 song          | Rockpools            
 status        | 200                  
 ts            | 1538352117000        
 userAgent     | Mozilla/5.0 (Wind... 
 userId        | 30                   
-RECORD 1-----------------------------
 auth          | Logged In            
 gender        | M                    
 itemInSession | 79                   
 length        | 236.09424            
 level         | free                 
 location      | Boston-Cambridge-... 
 method        | PUT                  
 page          | NextSong             
 registration  | 15383316

Since we are building a model that focuses on the user, remove any rows whose `userId` is null, NaN, or an empty string

In [8]:
from pyspark.sql.functions import isnan, isnull

In [9]:
data.filter((isnan(data['userId'])) | (data['userId'].isNull()) | (data['userId'] == "")).count()

8346

In [10]:
# Keep only rows whose userId is neither null nor an empty string
data = data.filter(data['userId'].isNotNull() & (data['userId'] != ""))
data.filter((isnan(data['userId'])) | (data['userId'].isNull()) | (data['userId'] == "")).count()

0

The times in the `registration` and `ts` columns are given in milliseconds; we will convert them to seconds.

In [11]:
from pyspark.sql.types import DoubleType
import pyspark.sql.functions as F

In [12]:
# Divide with native column expressions instead of a Python UDF; the result is still a double
data = data.withColumn("registration", F.col("registration") / 1000). \
    withColumn("ts", F.col("ts") / 1000)

data.select('registration', 'ts').show()

+-------------+-------------+
| registration|           ts|
+-------------+-------------+
|1.538173362E9|1.538352117E9|
| 1.53833163E9| 1.53835218E9|
|1.538173362E9|1.538352394E9|
| 1.53833163E9|1.538352416E9|
|1.538173362E9|1.538352676E9|
| 1.53833163E9|1.538352678E9|
| 1.53833163E9|1.538352886E9|
|1.538173362E9|1.538352899E9|
|1.538173362E9|1.538352905E9|
|1.538173362E9|1.538353084E9|
| 1.53833163E9|1.538353146E9|
| 1.53833163E9| 1.53835315E9|
|1.538173362E9|1.538353218E9|
| 1.53833163E9|1.538353375E9|
| 1.53833163E9|1.538353376E9|
|1.538173362E9|1.538353441E9|
| 1.53833163E9|1.538353576E9|
|1.537365219E9|1.538353668E9|
|1.538173362E9|1.538353687E9|
| 1.53833163E9|1.538353744E9|
+-------------+-------------+
only showing top 20 rows
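
As a sanity check on the units, the epoch values can be converted to calendar dates. The sketch below uses plain Python rather than Spark and takes the `ts` and `registration` values of the first record shown earlier; the helper name `ms_to_utc` is introduced here for illustration only.

```python
from datetime import datetime, timezone

def ms_to_utc(ms):
    """Convert epoch milliseconds to a timezone-aware UTC datetime."""
    return datetime.fromtimestamp(ms / 1000, tz=timezone.utc)

# Raw values of the first record in the data set (userId 30)
first_ts = ms_to_utc(1538352117000)
first_registration = ms_to_utc(1538173362000)

print(first_ts.isoformat())
print(first_registration.isoformat())
```

Both values land in autumn 2018, which is plausible for an activity log. Read as seconds instead of milliseconds, the same numbers would fall far outside any valid date range.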



#### Label

As mentioned above, we derive the label `churn` from the `page` column. Churn has only two possible values, churned or not churned, so this will be a binary column with 1 for churned and 0 for not churned.

In [13]:
from pyspark.sql.types import IntegerType

In [14]:
# Flag the cancellation event with a native expression instead of a Python UDF
cancelation_event = F.when(F.col("page") == "Cancellation Confirmation", 1).otherwise(0)
data = data.withColumn("churn", cancelation_event)
data.filter(data['userId'] == 125).describe('churn').show()

+-------+-------------------+
|summary|              churn|
+-------+-------------------+
|  count|                 11|
|   mean|0.09090909090909091|
| stddev|0.30151134457776363|
|    min|                  0|
|    max|                  1|
+-------+-------------------+



The problem here is that a particular user has multiple entries, since this is a log of user activity. A user may have visited a number of other pages before reaching the `Cancellation Confirmation` page. For a user who has churned, we want every one of their log entries to record 1 in the `churn` column. For this we will use `Window` from `pyspark.sql.window`.


In [15]:
from pyspark.sql.window import Window
from pyspark.sql.functions import sum as Fsum

In [16]:
window = Window.partitionBy("userId") \
        .rangeBetween(Window.unboundedPreceding,
                      Window.unboundedFollowing)
data = data.withColumn("churn", Fsum("churn").over(window))


In [17]:
data.filter(data['userId'] == 125).describe('churn').show()

+-------+-----+
|summary|churn|
+-------+-----+
|  count|   11|
|   mean|  1.0|
| stddev|  0.0|
|    min|    1|
|    max|    1|
+-------+-----+
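
The window above has no ordering, so its frame covers every row of a user, and the per-user sum is written back to each of those rows. The same logic in plain Python, on a few hypothetical events (the user ids and page sequences below are made up for illustration), might look like this:

```python
from collections import defaultdict

# Hypothetical event log: (userId, page)
events = [
    ("u1", "NextSong"),
    ("u1", "Cancel"),
    ("u1", "Cancellation Confirmation"),
    ("u2", "NextSong"),
    ("u2", "Thumbs Up"),
]

# Step 1: flag each event, as the per-row churn column does
flags = [(user, 1 if page == "Cancellation Confirmation" else 0)
         for user, page in events]

# Step 2: sum the flags per user (the window partitioned by userId)
per_user = defaultdict(int)
for user, flag in flags:
    per_user[user] += flag

# Step 3: write the per-user total back onto every row of that user
propagated = [(user, per_user[user]) for user, _ in flags]
print(propagated)
```

Because the flags are summed, a user with more than one confirmation event would receive a value above 1. Taking the maximum instead of the sum would keep the column strictly binary in that case.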



Let's see how many unique users we have

In [18]:
data.select("userId").dropDuplicates().count()

225

The total number of rows that belong to churned users. Again, each user has multiple entries here

In [19]:
data.filter(data['churn'] == 1).count()

44864

Now let's create a label dataframe that has one entry per user

In [20]:
from pyspark.sql.functions import col

In [21]:
label = data \
    .select('userId', col('churn').alias('label')) \
    .dropDuplicates()
label.show()

+------+-----+
|userId|label|
+------+-----+
|100010|    0|
|200002|    0|
|   125|    1|
|   124|    0|
|    51|    1|
|     7|    0|
|    15|    0|
|    54|    1|
|   155|    0|
|100014|    1|
|   132|    0|
|   154|    0|
|   101|    1|
|    11|    0|
|   138|    0|
|300017|    0|
|100021|    1|
|    29|    1|
|    69|    0|
|   112|    0|
+------+-----+
only showing top 20 rows



#### Feature engineering

Users who have been with the service for a long time tend to stay, even as paying customers, more than those who signed up recently.

Every company that sells products or services to customers has some notion of a customer's "lifetime value". With that in mind, we create our first feature: the time between a user's registration and their most recent activity.

We converted the times to seconds above. After computing the lifetime of each user, we convert it to days.

In [22]:
from pyspark.sql.functions import sum as Fsum, col

In [23]:
time_since_registration = data \
    .select('userId', 'registration', 'ts') \
    .withColumn('lifetime', (data.ts - data.registration)) \
    .groupBy('userId') \
    .agg({'lifetime': 'max'}) \
    .withColumnRenamed('max(lifetime)', 'lifetime') \
    .select('userId', (col('lifetime') / 3600 / 24).alias('lifetime'))

In [24]:
time_since_registration.describe('lifetime').show()

+-------+-------------------+
|summary|           lifetime|
+-------+-------------------+
|  count|                225|
|   mean|   79.8456834876543|
| stddev|  37.66147001861254|
|    min|0.31372685185185184|
|    max|  256.3776736111111|
+-------+-------------------+
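
To make the arithmetic concrete, here is a plain-Python sketch of the same calculation for a single user. The registration time and the two `ts` values are of the kind shown in the converted table earlier (already in seconds); treating them as events of one user is an assumption made for illustration, and `SECONDS_PER_DAY` is a name introduced here.

```python
SECONDS_PER_DAY = 24 * 3600

registration = 1538173362               # seconds since the epoch
event_times = [1538352117, 1538352394]  # ts of two events, in seconds

# Lifetime is the largest gap between registration and any event,
# i.e. the gap to the most recent event
lifetime_seconds = max(t - registration for t in event_times)
lifetime_days = lifetime_seconds / SECONDS_PER_DAY

print(round(lifetime_days, 2))
```

With only these two events the lifetime comes out at roughly two days. The feature table takes the maximum over all of a user's events, so the value stored there for a real user can be considerably larger.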



Businesses big and small rely on their customers not only coming back for more goods and services, but also referring the business to friends and family. It is in our nature to recommend things that we like. With this idea in mind, we create another feature.

Again we use the `page` column, this time looking specifically for the `Add Friend` page.

In [25]:
referring_friends = data \
    .select('userId', 'page') \
    .where(data.page == 'Add Friend') \
    .groupBy('userId') \
    .count() \
    .withColumnRenamed('count', 'add_friend')

In [26]:
referring_friends.describe('add_friend').show()

+-------+------------------+
|summary|        add_friend|
+-------+------------------+
|  count|               206|
|   mean|20.762135922330096|
| stddev|20.646779074405007|
|    min|                 1|
|    max|               143|
+-------+------------------+
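
Note that the count above is 206 while the data set has 225 users: users who never visited `Add Friend` have no row in `referring_friends`. When the features are combined later, those users need an explicit 0. A plain-Python sketch of that zero-fill, on hypothetical user ids and events, could look like this:

```python
from collections import Counter

# Hypothetical users and page events, made up for illustration
all_users = ["u1", "u2", "u3"]
events = [
    ("u1", "Add Friend"),
    ("u1", "NextSong"),
    ("u1", "Add Friend"),
    ("u2", "Add Friend"),
    ("u3", "NextSong"),
]

# Count 'Add Friend' events per user; users without any are absent here
add_friend = Counter(user for user, page in events if page == "Add Friend")

# Zero-fill so that every user gets a value for the feature
add_friend_feature = {user: add_friend.get(user, 0) for user in all_users}
print(add_friend_feature)
```

In Spark, a left join from the label table followed by `fillna(0)` is one common way to get the same effect.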



#### Modeling