# Udacity Capstone - Sparkify

## Initialize Script

### Module Imports

In [3]:
import boto3
import os
import pandas as pd
import plotly.express as px
import pyspark.ml as sm
import pyspark.ml.classification as smc
import pyspark.ml.evaluation as sme
import pyspark.ml.feature as smf
import pyspark.ml.tuning as smt
import pyspark.sql.functions as ssf
import pyspark.sql.types as sst
import pyspark.sql.window as ssw
import re

from pyspark.sql import SparkSession

### User Input

In [4]:
# Simple toggle, use full/mini dataset
use_full_dataset = True

### Environment Setup

In [5]:
assert isinstance(use_full_dataset, bool), 'Invalid input for use_full_dataset'

In [8]:
# Set up a spark session
spark = SparkSession.builder.appName(
    'Sparkify'
).config(
    'spark.master', 'local[*,4]'
).config(
    'spark.task.maxFailures', '4'
).config(
    'spark.driver.memory', '12g'
).config(
    'spark.executor.memory', '12g'
).getOrCreate()

## Fetch Data

Quick script to fetch sample data from s3. While this script has been tested on AWS, it was ultimately executed on a Local spark deployment.

## Clean Data

### Data Exploration

#### Load Data

In [9]:
data_dir = 'data'
data_file = 'sparkify_event_data.json' if use_full_dataset else 'mini_sparkify_event_data.json'
data_path = f"{data_dir}/{data_file}"

data_raw = spark.read.json(data_path)

data_raw = data_raw.repartition(60, 'userId')

#### Schema exploration

In [10]:
# Check available fields
data_raw.printSchema()

root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)



In [11]:
# Preview data
data_raw.limit(5).toPandas()

Unnamed: 0,artist,auth,firstName,gender,itemInSession,lastName,length,level,location,method,page,registration,sessionId,song,status,ts,userAgent,userId
0,Paramore,Logged In,Logan,M,161,Gregory,218.09587,paid,"Marshall, TX",PUT,NextSong,1537448916000,19480,Ignorance (Album Version),200,1538352015000,"""Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebK...",1390009
1,Joyce Cooling,Logged In,Keyla,F,124,Mcgee,248.11057,paid,"Atlanta-Sandy Springs-Roswell, GA",PUT,NextSong,1536047986000,21942,It's Time I Go (Jazz),200,1538352088000,Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; r...,1919555
2,Mercury Rev,Logged In,Logan,M,162,Gregory,151.92771,paid,"Marshall, TX",PUT,NextSong,1537448916000,19480,You're My Queen,200,1538352233000,"""Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebK...",1390009
3,I-Roy,Logged In,Keyla,F,125,Mcgee,146.88608,paid,"Atlanta-Sandy Springs-Roswell, GA",PUT,NextSong,1536047986000,21942,Black Is My Color,200,1538352336000,Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; r...,1919555
4,The Ergs!,Logged In,Logan,M,163,Gregory,8.93342,paid,"Marshall, TX",PUT,NextSong,1537448916000,19480,Sneak Attack,200,1538352384000,"""Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebK...",1390009


#### Health Check

In [None]:
total_rows = data_raw.count()
summaries = []
null_counts = []
for col in data_raw.columns:

    # Get no. of values
    value_counts = data_raw.groupby(col).count()
    value_counts = value_counts.orderBy('count', ascending=False)

    # Make sure null count always comes through
    null_count = value_counts.where(value_counts[col].isNull())
    value_counts = value_counts.where(value_counts[col].isNotNull())

    # Convert output to Pandas df, row limit to prevent any memory issues
    value_counts = value_counts.limit(25).toPandas()
    null_count = null_count.toPandas()

    # Create summary dataframes for selected column
    summary = pd.concat([value_counts, null_count], axis=0, ignore_index=True)
    summary.columns = ['value', 'count']
    summary.loc[:, 'field'] = col
    summary = summary.sort_values(by='count', ascending=False)

    if null_count.empty:
        null_count = pd.DataFrame({
            'value': [None],
            'count': [0],
            'field': [col]
        })
    else:
        null_count.loc[:, 'field'] = col

    # Save output to memory
    summaries.append(summary)
    null_counts.append(null_count)

# Combine summary dataframes for each column
value_summary = pd.concat(summaries, axis=0, ignore_index=True)
null_summary = pd.concat(null_counts, axis=0, ignore_index=True)

# Calculate counts as percentages
value_summary.loc[:, 'percentage'] = value_summary['count']/total_rows
null_summary.loc[:, 'percentage'] = null_summary['count']/total_rows

# Standardize column order
value_summary = value_summary[['field', 'value', 'count', 'percentage']]
null_summary = null_summary[['field', 'value', 'count', 'percentage']]

In [14]:
# Write to disk
value_summary.to_excel('summaries/value_summary.xlsx', index=False)
null_summary.to_excel('summaries/null_summary.xlsx', index=False)

Notes (from running on mini dataset):
* Records with null userId might be useful, need to look at them in more detail
* Some encoding errors, but that shouldn't impact the model
* itemInSession looks useful for feature engineering
* names will need removing
* paid/free info from level will be useful
* what is registration? looks like it might be a timestamp
* http response codes in response could be useful
    * 307 is just a redirect, but 404 would be user-impacting
* user agent could be used to get platform info
* slightly more cancellations submitted than confirmed
* significantly more downgrades started than confirmed

#### Null Investigation

It appears that in the full dataset, the Null userId issue isn't present

In [15]:
# Records with missing user IDs
data_null = data_raw.where(data_raw['userId'].isNull())
null_sample = data_null.limit(1000).toPandas()
null_sample
# No results? Interesting...
# Some userIds are populated with an empty string, not technically Null

Unnamed: 0,artist,auth,firstName,gender,itemInSession,lastName,length,level,location,method,page,registration,sessionId,song,status,ts,userAgent,userId


In [16]:
# Records with empty userIds
data_empty = data_raw.where(data_raw['userId'] == '')
empty_sample = data_empty.limit(1000).toPandas()
empty_sample

Unnamed: 0,artist,auth,firstName,gender,itemInSession,lastName,length,level,location,method,page,registration,sessionId,song,status,ts,userAgent,userId


In [17]:
data_empty.groupby(['page', 'auth']).count().toPandas()
# Looks like these are valid records, but without a userId they aren't useable

Unnamed: 0,page,auth,count


In [18]:
# Records with missing song info
# Probably a valid reason for this, but worth checking just in case
data_null = data_raw.where(
    data_raw['userId'].isNotNull() & data_raw['song'].isNull())
null_sample = data_null.limit(1000).toPandas()
null_sample
# These are still useful records, contains all data where the page isn't nextSong

Unnamed: 0,artist,auth,firstName,gender,itemInSession,lastName,length,level,location,method,page,registration,sessionId,song,status,ts,userAgent,userId
0,,Logged In,Logan,M,164,Gregory,,paid,"Marshall, TX",PUT,Thumbs Up,1537448916000,19480,,307,1538352385000,"""Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebK...",1390009
1,,Logged In,Mikiyah,F,166,Williams,,paid,"New York-Newark-Jersey City, NY-NJ-PA",GET,Home,1535529597000,11733,,200,1538354059000,Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) G...,1178731
2,,Logged In,Kyle,M,0,Johns,,free,"Sacramento--Roseville--Arden-Arcade, CA",GET,Home,1537057337000,13907,,200,1538354097000,"""Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebK...",1809452
3,,Logged In,John,M,2,Anderson,,paid,"New York-Newark-Jersey City, NY-NJ-PA",PUT,Thumbs Down,1526345905000,14956,,307,1538354433000,"""Mozilla/5.0 (Windows NT 6.2; WOW64) AppleWebK...",1855442
4,,Logged In,Keyla,F,137,Mcgee,,paid,"Atlanta-Sandy Springs-Roswell, GA",GET,Downgrade,1536047986000,21942,,200,1538354676000,Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; r...,1919555
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,,Logged In,Jamie,F,83,Wolfe,,paid,"San Francisco-Oakland-Hayward, CA",PUT,Thumbs Down,1531764446000,2368,,307,1538434344000,"""Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebK...",1351140
996,,Logged In,Samuel,M,12,Nelson,,free,"Boulder, CO",GET,Roll Advert,1537455856000,28743,,200,1538434387000,"""Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebK...",1561529
997,,Logged In,Guy,M,56,Robertson,,free,"Providence-Warwick, RI-MA",GET,Roll Advert,1534095159000,14869,,200,1538434664000,"""Mozilla/5.0 (Windows NT 6.3; WOW64) AppleWebK...",1311473
998,,Logged In,Mia,F,52,Medina,,free,"Los Angeles-Long Beach-Anaheim, CA",GET,Error,1538028282000,18711,,404,1538434686000,"""Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4...",1814851


#### Registration Investigation

In [19]:
# Check theory that registration is just a timestamp
data_reg = data_raw.where(data_raw['page'].isin({
    'Submit Registration',
    'Register'
}))
data_reg = data_reg.limit(1000).toPandas()
data_reg

Unnamed: 0,artist,auth,firstName,gender,itemInSession,lastName,length,level,location,method,page,registration,sessionId,song,status,ts,userAgent,userId
0,,Guest,,,5,,,free,,GET,Register,,15008,,200,1538378337000,,1261737
1,,Guest,,,1,,,free,,GET,Register,,15002,,200,1538418366000,,1261737
2,,Guest,,,2,,,free,,PUT,Submit Registration,,15002,,307,1538418367000,,1261737
3,,Guest,,,12,,,free,,GET,Register,,15002,,200,1538419438000,,1261737
4,,Guest,,,13,,,free,,PUT,Submit Registration,,15002,,307,1538419439000,,1261737
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,,Guest,,,4,,,free,,PUT,Submit Registration,,2528,,307,1540986433000,,1261737
996,,Guest,,,2,,,free,,GET,Register,,14089,,200,1541015492000,,1261737
997,,Guest,,,7,,,free,,GET,Register,,14089,,200,1541015915000,,1261737
998,,Guest,,,8,,,free,,PUT,Submit Registration,,14089,,307,1541015916000,,1261737


In [26]:
# Pick a session ID which features in data_reg
# The currently selected ID is from the full dataset, might not be present in the mini version
data_session = data_raw.where(data_raw['sessionId'] == 15002)
data_session = data_session.toPandas()

In [30]:
# Get actual registration time, compare to values in registration column
reg_time = data_session.loc[data_session['page'] == 'Submit Registration', 'ts'].values[-1]
reg_min = data_session['registration'].min()
reg_max = data_session['registration'].max()
reg_time, reg_min, reg_max

(1538419439000, 1535015137000.0, 1538391976000.0)

In [31]:
# View the time difference between actual & recorded registration time
reg_time = pd.Timestamp(reg_time, unit='ms')
reg_min = pd.Timestamp(reg_min, unit='ms')
reg_max = pd.Timestamp(reg_max, unit='ms')
reg_min - reg_time

Timedelta('-40 days +14:21:38')

In [33]:
# Export the session info for evaluation
data_session.to_excel('summaries/sample_session.xlsx', index=False)

Notes
<br><i>Registration time is constant when a user is logged in, but it looks like there can be a difference of several days between submitting a registration and the value
held in the registration column. Hypothesis is that ts could be sourced from the users device while registration is a timestamp generated by the sparkify system. There could also be a batch process updating the registration field? Will tentatively try using the values in this column.</i>

<i>userAgent, userId, location, registration, gender, firstName, lastName can all be filled in based on the sessionId, should improve data availability</i>

### Data Cleaning

#### Fill in the Gaps

In [34]:
# Get values for each session ID
data_gapfill = data_raw.groupby('sessionId').agg(
    # pylint: disable=no-member
    ssf.max('userAgent').alias('userAgent'),
    ssf.max('userId').alias('userId'),
    ssf.max('location').alias('location'),
    ssf.max('registration').alias('registration'),
    ssf.max('gender').alias('gender'),
    ssf.max('firstName').alias('firstName'),
    ssf.max('lastName').alias('lastName')
).persist()

In [35]:
# Drop these columns from the original dataset
data_cleaned = data_raw.drop(
    'userAgent', 'userId', 'location', 'registration', 'gender', 'firstName', 'lastName')

In [36]:
# Merge back to ensure fields are fully populated
data_cleaned = data_cleaned.join(data_gapfill, on='sessionId', how='inner')
data_cleaned.where(data_raw['sessionId'] == 1719).toPandas()

Unnamed: 0,sessionId,artist,auth,itemInSession,length,level,method,page,song,status,ts,userAgent,userId,location,registration,gender,firstName,lastName
0,1719,,Logged Out,5,,paid,GET,Home,,200,1538620370000,Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; r...,1580801,"Riverside-San Bernardino-Ontario, CA",1537563950000,M,Kenneth,Hart
1,1719,,Logged Out,6,,paid,PUT,Login,,307,1538620371000,Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; r...,1580801,"Riverside-San Bernardino-Ontario, CA",1537563950000,M,Kenneth,Hart
2,1719,,Logged Out,96,,paid,GET,Home,,200,1538639548000,Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; r...,1580801,"Riverside-San Bernardino-Ontario, CA",1537563950000,M,Kenneth,Hart
3,1719,,Logged Out,97,,paid,PUT,Login,,307,1538639549000,Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; r...,1580801,"Riverside-San Bernardino-Ontario, CA",1537563950000,M,Kenneth,Hart
4,1719,,Logged Out,19,,free,GET,Home,,200,1539360022000,Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; r...,1580801,"Riverside-San Bernardino-Ontario, CA",1537563950000,M,Kenneth,Hart
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
167,1719,Mr. Scruff,Logged In,47,259.44771,free,PUT,NextSong,Shrimp,200,1539364488000,Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; r...,1580801,"Riverside-San Bernardino-Ontario, CA",1537563950000,M,Kenneth,Hart
168,1719,,Logged In,48,,free,GET,Roll Advert,,200,1539364510000,Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; r...,1580801,"Riverside-San Bernardino-Ontario, CA",1537563950000,M,Kenneth,Hart
169,1719,The Ruts,Logged In,49,338.96444,free,PUT,NextSong,West One (Shine On Me),200,1539364747000,Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; r...,1580801,"Riverside-San Bernardino-Ontario, CA",1537563950000,M,Kenneth,Hart
170,1719,Lonnie Gordon,Logged In,50,181.21098,free,PUT,NextSong,Catch You Baby (Steve Pitron & Max Sanna Radio...,200,1539365085000,Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; r...,1580801,"Riverside-San Bernardino-Ontario, CA",1537563950000,M,Kenneth,Hart


#### Process Text Data

In [37]:
# Get platform from user agent string
platform_getter = ssf.regexp_extract(
    data_cleaned['userAgent'],
    r'[\w\/\.]+ \(([\w\s\.]+);.*\)',
    1
)

data_cleaned = data_cleaned.withColumn('platform', platform_getter)

In [38]:
def extract_state(location):
    # If multiple state codes given, just use the last one
    # Seems to happen with metropolitan areas which cross state lines?
    if not isinstance(location, str):
        return None
    state = location.split(',')[-1]
    state = state.split('-')[-1].strip()
    return state

state_getter = ssf.udf(extract_state, sst.StringType())
data_cleaned = data_cleaned.withColumn('state', state_getter('location'))

In [39]:
# Remove extra tags from song titles, standardize case
# - e.g. Song Title [feat. artist] (Album Version) -->  SONG TITLE

def extract_song(song):

    if not isinstance(song, str):
        return None

    # Remove any trailing brackets
    song = re.sub(r' \[.+\]', '', song)
    song = re.sub(r' \(.+\)', '', song)

    # Remove any non-standard characters
    song = re.sub(r'[^\w\s]+', '', song)

    # Fix any duplicated spaces
    song = re.sub(r'\s\s+', ' ', song)

    # Standardize case, remove trailing whitespace
    song = song.strip().upper()
    return song

song_getter = ssf.udf(extract_song, sst.StringType())
data_cleaned = data_cleaned.withColumn('songCleaned', song_getter('song'))

## EDA

#### Define Churn

In [40]:
# Indicate the exact moment a customer churns
data_cleaned = data_cleaned.withColumn(
    'churnFlag',
    ssf.when(data_cleaned['page'] == 'Cancellation Confirmation', 1).otherwise(0)
).withColumn(
    'downgradeFlag',
    ssf.when(data_cleaned['page'] == 'Submit Downgrade', 1).otherwise(0)
)

In [41]:
# Flag all records corresponding to a churned customer
churned = data_cleaned.where(data_cleaned['churnFlag'] == 1).select('userId').distinct()
churned = churned.toPandas()['userId'].tolist()
data_cleaned = data_cleaned.withColumn(
    'userChurnFlag',
    ssf.when(data_cleaned['userId'].isin(*churned), 1).otherwise(0))

data_cleaned = data_cleaned.persist()

#### Investigate data

In [42]:
# Simple Counts
user_stats = data_cleaned.groupby('userChurnFlag', 'userId').agg(
    #pylint: disable=no-member
    ssf.countDistinct('artist').alias('noArtists'),
    ssf.countDistinct('song').alias('noSongs'),
    ssf.count('song').alias('noPlays'),
    ssf.max('gender').alias('gender'),
    ssf.max('state').alias('state'),
    ssf.max('registration').alias('registration'),
    ssf.mean('length').alias('meanSongLength')
)
user_stats = user_stats.toPandas()
user_stats.loc[:, 'registration'] = user_stats['registration'].map(
    lambda x: pd.Timestamp(x, unit='ms'))

In [43]:
# Items in Session
item_stats = data_cleaned.groupby('userChurnFlag', 'sessionId').agg(
    #pylint: disable=no-member
    ssf.max('itemInSession').alias('sessionLength'),
    ssf.max('platform').alias('platform')
)
item_stats = item_stats.toPandas()

In [44]:
# Page visits
page_stats = data_cleaned.groupby('userChurnFlag', 'page').count()
page_stats = page_stats.toPandas()

In [45]:
# Status codes
status_stats = data_cleaned.groupby('userChurnFlag', 'status').count()
status_stats = status_stats.toPandas()

#### Generate Plots

In [48]:
# Simple counts (continuous)
id_cols = {'userChurnFlag', 'userId'}
cat_cols = {'gender', 'state'}
for user_col in [x for x in user_stats.columns if x not in id_cols.union(cat_cols)]:
    fig = px.histogram(
        user_stats,
        x=user_col,
        color='userChurnFlag',
        barmode='overlay',
        histnorm='percent',
        labels={0: 'No Churn', 1: 'Churn'},
        nbins=50
        )
    fig.write_html(f'figures/{user_col}StatsPerc.html')

    fig = px.histogram(
        user_stats,
        x=user_col,
        color='userChurnFlag',
        barmode='overlay',
        labels={0: 'No Churn', 1: 'Churn'},
        nbins=50)
    fig.write_html(f'figures/{user_col}StatsAbs.html')

In [49]:
# Simple counts (discrete)
group_counts = user_stats.groupby('userChurnFlag').agg(
    totalUsers = ('userId', 'nunique')
).reset_index()
for user_col in cat_cols:
    plot_df = user_stats.groupby(['userChurnFlag', user_col]).agg(
        noUsers = ('userId', 'nunique')
    ).reset_index()
    plot_df = plot_df.merge(group_counts, on='userChurnFlag', how='inner')
    plot_df.loc[:, 'percUsers'] = plot_df['noUsers'] / plot_df['totalUsers']
    plot_df = plot_df.sort_values(by='noUsers', ascending=False)
    plot_df.loc[:, 'userChurnFlag'] = plot_df['userChurnFlag'].astype(bool)

    fig = px.bar(
        plot_df,
        x=user_col,
        y='percUsers',
        color='userChurnFlag',
        barmode='group',
        labels={True: 'Churn', False: 'No Churn'}
    )
    fig.write_html(f'figures/{user_col}StatsPerc.html')

    fig = px.bar(
        plot_df,
        x=user_col,
        y='noUsers',
        color='userChurnFlag',
        barmode='group',
        labels={True: 'Churn', False: 'No Churn'}
    )
    fig.write_html(f'figures/{user_col}StatsAbs.html')

In [50]:
# Session length
fig = px.histogram(
    item_stats,
    x='sessionLength',
    color='userChurnFlag',
    barmode='overlay',
    histnorm='percent',
    labels={0: 'No Churn', 1: 'Churn'},
    nbins=50)
fig.write_html('figures/sessionLengthStatsPerc.html')

fig = px.histogram(
    item_stats,
    x='sessionLength',
    color='userChurnFlag',
    barmode='overlay',
    labels={0: 'No Churn', 1: 'Churn'},
    nbins=50)
fig.write_html('figures/sessionLengthStatsAbs.html')

In [51]:
# Platform
fig = px.histogram(
    item_stats,
    x='platform',
    color='userChurnFlag',
    barmode='group',
    histnorm='percent',
    labels={0: 'No Churn', 1: 'Churn'}
)
fig.write_html('figures/platformStatsPerc.html')

fig = px.histogram(
    item_stats,
    x='platform',
    color='userChurnFlag',
    barmode='group',
    labels={0: 'No Churn', 1: 'Churn'}
)
fig.write_html('figures/platformStatsAbs.html')

In [52]:
# Page visits
page_stats.loc[:, 'userChurnFlag'] = page_stats['userChurnFlag'].astype(bool)
group_sums = page_stats.groupby('userChurnFlag').agg(
    totalVisits = ('count', 'sum')
).reset_index()
page_stats = page_stats.merge(group_sums, on='userChurnFlag', how='inner')
page_stats.loc[:, 'percentage'] = page_stats['count'] / page_stats['totalVisits']
page_stats = page_stats.sort_values(by='count', ascending=False)

fig = px.bar(
    page_stats,
    x='page',
    y='count',
    color='userChurnFlag',
    barmode='group',
)
fig.write_html('figures/pageStatsAbs.html')

fig = px.bar(
    page_stats,
    x='page',
    y='percentage',
    color='userChurnFlag',
    barmode='group',
)
fig.write_html('figures/pageStatsPerc.html')

In [53]:
# Status codes
status_stats.loc[:, 'userChurnFlag'] = status_stats['userChurnFlag'].astype(bool)
group_sums = status_stats.groupby('userChurnFlag').agg(
    totalResponses = ('count', 'sum')
).reset_index()
status_stats = status_stats.merge(group_sums, on='userChurnFlag', how='inner')
status_stats.loc[:, 'percentage'] = status_stats['count'] / status_stats['totalResponses']
status_stats = status_stats.sort_values(by='count', ascending=False)

fig = px.bar(
    status_stats,
    x='status',
    y='count',
    color='userChurnFlag',
    barmode='group',
)
fig.write_html('figures/statusStatsAbs.html')

fig = px.bar(
    status_stats,
    x='status',
    y='percentage',
    color='userChurnFlag',
    barmode='group',
)
fig.write_html('figures/statusStatsPerc.html')

#### Deal with an outlier
This can also be observed in the generated "value_summary.xlsx". Since there are no null userIds in the full dataset, I suspect they've all been mistakenly tagged with userId 1261737.

In [56]:
user_stats.loc[user_stats['noPlays'] == user_stats['noPlays'].max(), :]

Unnamed: 0,userChurnFlag,userId,noArtists,noSongs,noPlays,gender,state,registration,meanSongLength
7282,1,1261737,31967,172570,3341405,M,WY,2018-11-25 09:07:40,248.715717


In [57]:
data_cleaned = data_cleaned.filter('userId != 1261737')

## Feature Engineering

In [61]:
# Trigger evaluation of the datset up to this point
data_cleaned = data_cleaned.orderBy(
    'userID', 'ts', ascending=True
)

#### Static variables (no time dependency)

In [62]:
static_vars = data_cleaned.groupby('userId').agg(
    #pylint: disable=no-member
    ssf.max('gender').alias('gender'),
    ssf.max('ts').alias('lastTs'),
    ssf.min('registration').alias('registration'),
    ssf.max('state').alias('state'),
    ssf.max('churnFlag').alias('userChurnFlag')
)

In [63]:
static_vars = static_vars.withColumn(
    'accountAge',
    static_vars['lastTs']-static_vars['registration'])

static_vars = static_vars.drop('lastTs', 'registration')

#### Dynamic Variables (evaluated over a time window)

In [64]:
def get_modal_system(data_in, name):
    sys_counts = data_in.groupby('userId', 'platform').agg(
        ssf.countDistinct('sessionId').alias('noSessions'))

    # No built-in mode function, need to manually rank using window functions
    # Using row_number rather than rank to avoid ties
    usr_window = ssw.Window.partitionBy(
        sys_counts['userId']
    ).orderBy(
        sys_counts['noSessions'].desc()
    )
    sys_counts = sys_counts.withColumn(
        #pylint: disable=no-member
        'platformRank', ssf.row_number().over(usr_window)
    )

    # Take first ranked platform for each 
    sys_modal = sys_counts.where(
        sys_counts['platformRank'] == 1
    ).select(
        'userId', f'platform'
    )

    sys_modal = sys_modal.withColumnRenamed('platform', f'platform{name}')

    return sys_modal

In [65]:
def get_simple_aggregates(data_in, name):

    flagged = data_in.withColumn(
        'httpError',
        ssf.when(data_in['status']==404, 1).otherwise(0)
    )
    dynamic_vars = flagged.groupby(
        'userId'
    ).agg(
        #pylint: disable=no-member
        ssf.count('platform').alias(f'noPlatforms{name}'),
        ssf.count('userAgent').alias(f'noSystems{name}'),
        ssf.sum('httpError').alias(f'httpErrors{name}'),
        ssf.countDistinct('artist').alias(f'noArtists{name}'),
        ssf.countDistinct('songCleaned').alias(f'noSongs{name}'),
        # Count excludes nulls, so this should be equivalent to counting
        # number of nextSong pages
        ssf.count('songCleaned').alias(f'noPlays{name}'),
        ssf.first('level').alias(f'levelStart{name}'),
        ssf.last('level').alias(f'levelEnd{name}')
    )

    return dynamic_vars

In [66]:
def get_session_stats(data_in, name):
    # Limit to song plays
    session_vars = data_in.filter(data_in['page'] == 'NextSong')

    # Get play time for each song
    session_window = ssw.Window.partitionBy(
        session_vars['userId'], session_vars['sessionId']
    ).orderBy(
        'ts'
    )
    session_vars = session_vars.withColumn(
        'nextPlayStart',
        ssf.lead(session_vars['ts']).over(session_window)
    )

    session_vars = session_vars.withColumn(
        'playTime',
        session_vars['nextPlayStart'] - session_vars['ts']
    )

    # Get play time for each session
    session_vars = session_vars.groupby(
        'userId', 'sessionId'
    ).agg(
        #pylint: disable=no-member
        ssf.sum('playTime').alias('sessionLength')
    )

    # Get play time stats for each user
    session_vars = session_vars.groupby(
        'userId'
    ).agg(
        #pylint: disable=no-member
        ssf.mean('sessionLength').alias(f'lengthMean{name}'),
        ssf.stddev('sessionLength').alias(f'lengthStd{name}'),
        ssf.sum('sessionLength').alias(f'lengthSum{name}')
    )

    return session_vars

In [67]:
def get_popularity_scores(data_in, name):
    # For each song, how much do they contribute to the total number of plays?
    song_data = data_in.filter(data_in['page'] == 'NextSong')
    song_totals = song_data.groupby('songCleaned').agg(
        #pylint: disable=no-member
        ssf.count('ts').alias('songTotal')
    )
    overall_total = song_totals.agg(
        #pylint: disable=no-member
        ssf.sum('songTotal').alias('overallTotal')
    ).collect()[0]['overallTotal']

    song_totals = song_totals.withColumn(
        'overallPerc',
        song_totals['songTotal'] / overall_total
    ).select(
        'songCleaned', 'overallPerc'
    )

    # For each user, how popular are the songs they're playing?
    popularity = song_data.select('userId', 'songCleaned')

    # Make a clone of popularity, bug in pyspark causes errors if this isn't done
    # https://stackoverflow.com/questions/45713290/how-to-resolve-the-analysisexception-resolved-attributes-in-spark
    popularity = spark.createDataFrame(popularity.rdd, popularity.schema)

    popularity = popularity.join(
        song_totals,
        popularity['songCleaned']==song_totals['songCleaned'],
        how='left')
    popularity = popularity.groupby('userId').agg(
        #pylint: disable=no-member
        ssf.sum('overallPerc').alias(f'popularityScore{name}')
    )

    return popularity

In [68]:
def get_page_clicks(data_in, name):

    def format_page(page):
        if not isinstance(page, str):
            return None
        
        page = page.replace(' ', '')
        page = f'{page}{name}'
        return page
    
    page_getter = ssf.udf(format_page, sst.StringType())

    page_clicks = data_in.withColumn('page', page_getter('page'))

    page_clicks = page_clicks.groupby('userId').pivot('page').count()

    # Ensure consistent column headers regardless of the dataset
    included_pages = [
        'About', 'Add Friend', 'Add to Playlist', 'Cancel', 'Downgrade',
        'Error', 'Help', 'Home', 'Login', 'Logout', 'NextSong', 'Register',
        'Roll Advert', 'Save Settings', 'Settings', 'Submit Downgrade', 'Submit Registration',
        'Submit Upgrade', 'Thumbs Down', 'Thumbs Up', 'Upgrade'
    ]

    included_pages = ['userId'] + [format_page(x) for x in included_pages]

    missing_pages = [x for x in included_pages if x not in page_clicks.columns]

    # Add a column for any missing pages
    for missing_page in missing_pages:
        page_clicks = page_clicks.withColumn(
            #pylint: disable=no-member
            missing_page,
            ssf.lit(0)
        )

    # Only bring through what's in the list
    page_clicks = page_clicks.select(*included_pages)

    return page_clicks

In [69]:
def get_page_clicks_alt(data_in, name):
    '''Alernate form of get_page_clicks, avoids use of the pivot function to
    (hopefully) allow the code to run when using the full dataset'''

    # Ensure consistent column headers regardless of the dataset
    # Note that cancellation submission is not included, as this would result
    # in a proxy column that contains the target label
    included_pages = [
        'About', 'Add Friend', 'Add to Playlist', 'Cancel', 'Downgrade',
        'Error', 'Help', 'Home', 'Login', 'Logout', 'NextSong', 'Register',
        'Roll Advert', 'Save Settings', 'Settings', 'Submit Downgrade', 'Submit Registration',
        'Submit Upgrade', 'Thumbs Down', 'Thumbs Up', 'Upgrade'
    ]

    def format_page(page):
        if not isinstance(page, str):
            return None
        
        page = page.replace(' ', '')
        page = f'{page}{name}'
        return page
    
    page_getter = ssf.udf(format_page, sst.StringType())

    page_clicks = data_in.withColumn('page', page_getter('page'))

    page_clicks = page_clicks.select('userId', 'page').dropna(how='any')

    included_pages = [format_page(x) for x in included_pages]

    # Add new field for count of each page
    for included_page in included_pages:
        page_clicks = page_clicks.withColumn(
            included_page,
            ssf.when(page_clicks['page'] == included_page, 1).otherwise(0)
        )
    
    # Set up list of aggregations to be applied
    aggregations = [
        #pylint: disable=no-member
        ssf.sum(col).alias(col)
        for col in included_pages
    ]

    page_clicks = page_clicks.groupby('userId').agg(*aggregations)

    return page_clicks

#### Generate aggregate statistics

In [70]:
# Get the max timestamp
limits = data_cleaned.agg(
    #pylint: disable=no-member
    ssf.max('ts').alias('maxTs'),
    ssf.min('ts').alias('minTs')
).collect()[0]
max_ts = limits['maxTs']
min_ts = limits['minTs']

In [71]:
day_delta = 24 * 60 * 60 * 1000
week_delta = 7 * day_delta
month_delta = 31 * day_delta

In [72]:
# How many months of data do we have?
(max_ts - min_ts) / month_delta

1.9677423088410992

In [73]:
start_times = {
    'week': max_ts - week_delta,
    'month': max_ts - month_delta,
    'twomonth': max_ts - 2*month_delta
}

In [76]:
outputs = None
for name, start_ts in start_times.items():

    name = name.title()

    # Get subset of data
    data_subset = data_cleaned.where(
        data_cleaned['ts'] >= start_ts
    ).persist()

    modal_system = get_modal_system(data_subset, name)

    simple_aggregates = get_simple_aggregates(data_subset, name)

    session_stats = get_session_stats(data_subset, name)

    popularity_scores = get_popularity_scores(data_subset, name)

    page_clicks = get_page_clicks_alt(data_subset, name)

    merged = modal_system.join(
        simple_aggregates,
        on='userId',
        how='full_outer'
    ).join(
        session_stats,
        on='userId',
        how='full_outer'
    ).join(
        popularity_scores,
        on='userId',
        how='full_outer'
    ).join(
        page_clicks,
        on='userId',
        how='full_outer'
    ).persist()

    if outputs is None:
        outputs = merged
    else:
        outputs = outputs.join(merged, on='userId', how='full_outer')

In [77]:
merged = static_vars.join(
    outputs,
    on='userId',
    how='full_outer'
)

platform_cols = [x for x in outputs.columns if 'platform' in x]
level_cols = [x for x in outputs.columns if 'level' in x]
merged = merged.fillna('unknown', subset=platform_cols)
merged = merged.fillna('free', subset=level_cols)
merged = merged.fillna(0)

#### One-hot Encoding

In [78]:
def fill_empty_string(string_in, fill_value='unknown'):
    if not isinstance(string_in, str):
        return fill_value
    elif not string_in:
        return fill_value
    else:
        return string_in

na_handler = ssf.udf(fill_empty_string, sst.StringType())

In [81]:
cat_cols = ['gender', 'state', 'platform', 'level']
cat_cols = [x for x in merged.columns if any((y for y in cat_cols if y in x))]

indexers = {}
for cat_col in cat_cols:
    merged = merged.withColumn(cat_col, na_handler(cat_col))
    indexer = smf.StringIndexer(
        inputCol=cat_col,
        outputCol=f"{cat_col}Inx"
    )
    indexer = indexer.fit(merged)
    merged = indexer.transform(merged)
    merged = merged.drop(cat_col).withColumnRenamed(f"{cat_col}Inx", cat_col)
    indexers[cat_col] = indexer

In [82]:
encoder = smf.OneHotEncoder(
    inputCols=cat_cols,
    outputCols=[f"{x}Vec" for x in cat_cols]
)
encoder = encoder.fit(merged)
encoded = encoder.transform(merged)
encoded = encoded.drop(*cat_cols)

encoded = encoded.persist()

## Modeling

#### Set up Pipeline

In [83]:
# Use VectorAssembler to combine all features into a single vector
feature_cols = [x for x in encoded.columns if x not in {'userId', 'userChurnFlag'}]
assembler = smf.VectorAssembler(
    inputCols=feature_cols,
    outputCol='features')
encoded = assembler.transform(encoded)
encoded = encoded.drop(*feature_cols)
encoded = encoded.withColumnRenamed('userChurnFlag', 'label')
encoded = encoded.persist()

encoded_sample = encoded.limit(5).toPandas()
encoded_sample

Unnamed: 0,userId,label,features
0,1000280,0,"(2963223000.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0..."
1,1002185,0,"(4257542000.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0..."
2,1030587,0,"(10689175000.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3,1033297,0,"(7635920000.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0..."
4,1057724,0,"(8298249000.0, 3.0, 3.0, 0.0, 1.0, 1.0, 1.0, 0..."


In [84]:
# Split out validation dataset
train, val = encoded.randomSplit([3.0, 1.0], seed=42)

In [85]:
# Set up scaler for numerical features
scaler = smf.StandardScaler(
    withStd=True,
    withMean=False,
    inputCol='features',
    outputCol='scaledFeatures')

In [86]:
# Use PCA to reduce dimensionality of scaled vectors
reducer = smf.PCA(
    k=10,
    inputCol=scaler.getOutputCol(),
    outputCol='selectedFeatures')

In [87]:
# Use a classifier to generate the final predictions
classifier = smc.GBTClassifier(
    labelCol='label',
    featuresCol=reducer.getOutputCol(),
    predictionCol='predictedLabel'
)

In [90]:
# Combine all steps in a pipeline
pipeline = sm.Pipeline(
    stages=[scaler, reducer, classifier]
)

In [91]:
# Create an evaluator which will quantify model performance
eval_f1 = sme.MulticlassClassificationEvaluator(
    labelCol='label',
    predictionCol='predictedLabel',
    metricName='f1'
)

In [92]:
# Set up a parameter grid for cross validation
param_grid = smt.ParamGridBuilder().addGrid(
    reducer.k, [10, 20, 50, 75]
).addGrid(
    classifier.maxDepth, [2, 5, 10]
).addGrid(
    classifier.subsamplingRate, [0.1, 0.2, 0.3]
).build()

In [93]:
# Bring everything together
validator = smt.CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=param_grid,
    evaluator=eval_f1,
    numFolds=3
)

#### Fit the Model to the Data

In [94]:
model = validator.fit(train)
model.save('final_model.bin')

In [95]:
train_predictions = model.transform(train)
val_predictions = model.transform(val)

#### Evaluate Model Performance

In [96]:
model.getEstimatorParamMaps()

[{Param(parent='PCA_f9413147e3f8', name='k', doc='the number of principal components'): 10,
  Param(parent='GBTClassifier_ca033f5e50db', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.'): 2,
  Param(parent='GBTClassifier_ca033f5e50db', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].'): 0.1},
 {Param(parent='PCA_f9413147e3f8', name='k', doc='the number of principal components'): 10,
  Param(parent='GBTClassifier_ca033f5e50db', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.'): 2,
  Param(parent='GBTClassifier_ca033f5e50db', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].'): 0.2},
 {Param(parent='PCA_f9413147e3f8', name='k', doc='the number of principal components'): 10,
  Param(parent='

In [97]:
train_f1 = eval_f1.evaluate(train_predictions)
val_f1 = eval_f1.evaluate(val_predictions)
print(f"Train F1: {train_f1:,.2f}\nValidation F1: {val_f1:,.2f}")

Train F1: 0.87
Validation F1: 0.86


In [98]:
eval_roc = sme.BinaryClassificationEvaluator(
    labelCol='label',
    rawPredictionCol='predictedLabel',
    metricName='areaUnderROC'
)

train_auc = eval_roc.evaluate(train_predictions)
val_auc = eval_roc.evaluate(val_predictions)
print(f"Train AUC: {train_auc:,.2f}\nValidation AUC: {val_auc:,.2f}")

Train AUC: 0.73
Validation AUC: 0.72


In [99]:
eval_accuracy = sme.MulticlassClassificationEvaluator(
    labelCol='label',
    predictionCol='predictedLabel',
    metricName='accuracy'
)

train_acc = eval_accuracy.evaluate(train_predictions)
val_acc = eval_accuracy.evaluate(val_predictions)
print(f"Train Accuracy: {train_acc:,.2f}\nValidation Accuracy: {val_acc:,.2f}")

Train Accuracy: 0.88
Validation Accuracy: 0.88


In [100]:
eval_precision = sme.MulticlassClassificationEvaluator(
    labelCol='label',
    predictionCol='predictedLabel',
    metricName='weightedPrecision'
)

train_prec = eval_precision.evaluate(train_predictions)
val_prec = eval_precision.evaluate(val_predictions)
print(f"Train Precision: {train_prec:,.2f}\nValidation Precision: {val_prec:,.2f}")

Train Precision: 0.89
Validation Precision: 0.88


In [101]:
eval_recall = sme.MulticlassClassificationEvaluator(
    labelCol='label',
    predictionCol='predictedLabel',
    metricName='weightedRecall'
)

train_rec = eval_recall.evaluate(train_predictions)
val_rec = eval_recall.evaluate(val_predictions)
print(f"Train Recall: {train_rec:,.2f}\nValidation Recall: {val_rec:,.2f}")

Train Recall: 0.88
Validation Recall: 0.88
