# Udacity Capstone - Sparkify

## Initialize Script

### Module Imports

In [1]:
import boto3
import numpy as np
import os
import pandas as pd
import plotly.figure_factory as ff
import plotly.express as px
import pyspark.ml as sm
import pyspark.ml.classification as smc
import pyspark.ml.evaluation as sme
import pyspark.ml.feature as smf
import pyspark.ml.tuning as smt
import pyspark.sql.functions as ssf
import pyspark.sql.types as sst
import pyspark.sql.window as ssw
import re

from pyspark.sql import SparkSession
from sklearn.metrics import confusion_matrix

### User Input
Running this script on the full dataset takes ~3 hours on my own local spark instance, suggest using the mini dataset for testing.

In [2]:
# Simple toggle, use full/mini dataset
use_full_dataset = True

### Environment Setup

In [3]:
assert isinstance(use_full_dataset, bool), 'Invalid input for use_full_dataset'

In [4]:
# Set up a spark session
spark = SparkSession.builder.appName(
    'Sparkify'
).config(
    # Use all cores, max 4 retries per task
    'spark.master', 'local[*,4]'
).config(
    # Max 4 retries per task
    'spark.task.maxFailures', '4'
).config(
    # Required to prevent out-of-memory issues
    'spark.driver.memory', '24g'
).config(
    # Required to prevent out-of-memory issues
    'spark.executor.memory', '24g'
).getOrCreate()

## Fetch Data

Quick script to fetch sample data from s3. While this script has been tested on AWS, it was ultimately executed on a Local spark deployment.
<br>Only run this section once, the download takes a while to complete.

## Clean Data

### Data Exploration

#### Load Data

In [5]:
# Set target file to laod
data_dir = 'data'
data_file = 'sparkify_event_data.json' if use_full_dataset else 'mini_sparkify_event_data.json'
data_path = f"{data_dir}/{data_file}"

# Load to spark dataframe, partition by userId
data_raw = spark.read.json(data_path)
data_raw = data_raw.repartition(60, 'userId')

#### Schema exploration

In [6]:
# Check available fields
data_raw.printSchema()

root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)



In [7]:
# Preview data
data_raw.limit(5).toPandas()

Unnamed: 0,artist,auth,firstName,gender,itemInSession,lastName,length,level,location,method,page,registration,sessionId,song,status,ts,userAgent,userId
0,Paramore,Logged In,Logan,M,161,Gregory,218.09587,paid,"Marshall, TX",PUT,NextSong,1537448916000,19480,Ignorance (Album Version),200,1538352015000,"""Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebK...",1390009
1,Joyce Cooling,Logged In,Keyla,F,124,Mcgee,248.11057,paid,"Atlanta-Sandy Springs-Roswell, GA",PUT,NextSong,1536047986000,21942,It's Time I Go (Jazz),200,1538352088000,Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; r...,1919555
2,Mercury Rev,Logged In,Logan,M,162,Gregory,151.92771,paid,"Marshall, TX",PUT,NextSong,1537448916000,19480,You're My Queen,200,1538352233000,"""Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebK...",1390009
3,I-Roy,Logged In,Keyla,F,125,Mcgee,146.88608,paid,"Atlanta-Sandy Springs-Roswell, GA",PUT,NextSong,1536047986000,21942,Black Is My Color,200,1538352336000,Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; r...,1919555
4,The Ergs!,Logged In,Logan,M,163,Gregory,8.93342,paid,"Marshall, TX",PUT,NextSong,1537448916000,19480,Sneak Attack,200,1538352384000,"""Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebK...",1390009


#### Health Check

In [8]:
total_rows = data_raw.count()
summaries = []
null_counts = []
for col in data_raw.columns:

    # Get no. of values
    value_counts = data_raw.groupby(col).count()
    value_counts = value_counts.orderBy('count', ascending=False)

    # Make sure null count always comes through
    null_count = value_counts.where(value_counts[col].isNull())
    value_counts = value_counts.where(value_counts[col].isNotNull())

    # Convert output to Pandas df, row limit to prevent any memory issues
    value_counts = value_counts.limit(25).toPandas()
    null_count = null_count.toPandas()

    # Create summary dataframes for selected column
    summary = pd.concat([value_counts, null_count], axis=0, ignore_index=True)
    summary.columns = ['value', 'count']
    summary.loc[:, 'field'] = col
    summary = summary.sort_values(by='count', ascending=False)

    if null_count.empty:
        null_count = pd.DataFrame({
            'value': [None],
            'count': [0],
            'field': [col]
        })
    else:
        null_count.loc[:, 'field'] = col

    # Save output to memory
    summaries.append(summary)
    null_counts.append(null_count)

# Combine summary dataframes for each column
value_summary = pd.concat(summaries, axis=0, ignore_index=True)
null_summary = pd.concat(null_counts, axis=0, ignore_index=True)

# Calculate counts as percentages
value_summary.loc[:, 'percentage'] = value_summary['count']/total_rows
null_summary.loc[:, 'percentage'] = null_summary['count']/total_rows

# Standardize column order
value_summary = value_summary[['field', 'value', 'count', 'percentage']]
null_summary = null_summary[['field', 'value', 'count', 'percentage']]

In [9]:
# Write to disk
value_summary.to_excel('summaries/value_summary.xlsx', index=False)
null_summary.to_excel('summaries/null_summary.xlsx', index=False)

Notes (from running on mini dataset):
* Records with null userId might be useful, need to look at them in more detail
* Some encoding errors, but that shouldn't impact the model
* itemInSession looks useful for feature engineering
* names will need removing
* paid/free info from level will be useful
* what is registration? looks like it might be a timestamp
* http response codes in response could be useful
    * 307 is just a redirect, but 404 would be user-impacting
* user agent could be used to get platform info
* slightly more cancellations submitted than confirmed
* significantly more downgrades started than confirmed

Notes from full dataset
* It looks like all null values have userId = 1261737 in this dataset?
* No truly null values in this dataset

In [10]:
# Records with null user IDs
data_null = data_raw.where(data_raw['userId'].isNull())
null_sample = data_null.limit(5).toPandas()
null_sample
# No results? Interesting...
# Some userIds are populated with an empty string, not technically Null

Unnamed: 0,artist,auth,firstName,gender,itemInSession,lastName,length,level,location,method,page,registration,sessionId,song,status,ts,userAgent,userId


In [11]:
# Records with empty userIds
if use_full_dataset:
    # As noted above, this userId has ~778k records in the full
    # dataset and is obviously anomalous
    data_empty = data_raw.where(data_raw['userId'] == 1261737)
else:
    # Just search for empty strings in the mini dataset
    data_empty = data_raw.where(data_raw['userId'] == '')
empty_sample = data_empty.limit(5).toPandas()
empty_sample

Unnamed: 0,artist,auth,firstName,gender,itemInSession,lastName,length,level,location,method,page,registration,sessionId,song,status,ts,userAgent,userId
0,,Logged Out,,,87,,,paid,,GET,Home,,8615,,200,1538352008000,,1261737
1,,Logged Out,,,0,,,free,,PUT,Login,,7433,,307,1538352041000,,1261737
2,,Logged Out,,,4,,,free,,GET,Home,,25003,,200,1538352182000,,1261737
3,,Logged Out,,,2,,,free,,GET,Home,,9930,,200,1538352254000,,1261737
4,,Logged Out,,,3,,,free,,PUT,Login,,9930,,307,1538352255000,,1261737


In [12]:
# Check which pages these records correspond do
data_empty.groupby(['page', 'auth']).count().toPandas()

Unnamed: 0,page,auth,count
0,Submit Registration,Guest,401
1,Error,Logged Out,840
2,Login,Logged Out,296350
3,Help,Logged Out,25023
4,Error,Guest,74
5,Home,Guest,1653
6,Help,Guest,629
7,Home,Logged Out,408325
8,About,Logged Out,43747
9,Register,Guest,802


<b>Note:</b> It looks like these are valid records, but without a userId they can't be used

In [13]:
# Records with missing song info
# Probably a valid reason for this, but worth checking just in case
data_null = data_raw.where(
    data_raw['userId'].isNotNull() & data_raw['song'].isNull())
null_sample = data_null.limit(5).toPandas()
null_sample

Unnamed: 0,artist,auth,firstName,gender,itemInSession,lastName,length,level,location,method,page,registration,sessionId,song,status,ts,userAgent,userId
0,,Logged In,Logan,M,164,Gregory,,paid,"Marshall, TX",PUT,Thumbs Up,1537448916000,19480,,307,1538352385000,"""Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebK...",1390009
1,,Logged In,Mikiyah,F,166,Williams,,paid,"New York-Newark-Jersey City, NY-NJ-PA",GET,Home,1535529597000,11733,,200,1538354059000,Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) G...,1178731
2,,Logged In,Kyle,M,0,Johns,,free,"Sacramento--Roseville--Arden-Arcade, CA",GET,Home,1537057337000,13907,,200,1538354097000,"""Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebK...",1809452
3,,Logged In,John,M,2,Anderson,,paid,"New York-Newark-Jersey City, NY-NJ-PA",PUT,Thumbs Down,1526345905000,14956,,307,1538354433000,"""Mozilla/5.0 (Windows NT 6.2; WOW64) AppleWebK...",1855442
4,,Logged In,Keyla,F,137,Mcgee,,paid,"Atlanta-Sandy Springs-Roswell, GA",GET,Downgrade,1536047986000,21942,,200,1538354676000,Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; r...,1919555


<b>Note:</b> These are still useful records, they contain all data where the page isn't nextSong

#### Registration Investigation

In [14]:
# Check theory that registration is just a timestamp
data_reg = data_raw.where(data_raw['page'].isin({
    'Submit Registration',
    'Register'
}))
data_reg = data_reg.limit(5).toPandas()
data_reg

Unnamed: 0,artist,auth,firstName,gender,itemInSession,lastName,length,level,location,method,page,registration,sessionId,song,status,ts,userAgent,userId
0,,Guest,,,5,,,free,,GET,Register,,15008,,200,1538378337000,,1261737
1,,Guest,,,1,,,free,,GET,Register,,15002,,200,1538418366000,,1261737
2,,Guest,,,2,,,free,,PUT,Submit Registration,,15002,,307,1538418367000,,1261737
3,,Guest,,,12,,,free,,GET,Register,,15002,,200,1538419438000,,1261737
4,,Guest,,,13,,,free,,PUT,Submit Registration,,15002,,307,1538419439000,,1261737


In [15]:
# Pick a session ID which features in data_reg
session_id = 15002 if use_full_dataset else 1719
data_session = data_raw.where(data_raw['sessionId'] == session_id)
data_session = data_session.toPandas()

In [16]:
# Get actual registration time, compare to values in registration column
reg_time = data_session.loc[data_session['page'] == 'Submit Registration', 'ts'].values[-1]
reg_min = data_session['registration'].min()
reg_max = data_session['registration'].max()
reg_time, reg_min, reg_max

(1538419439000, 1535015137000.0, 1538391976000.0)

In [17]:
# View the time difference between actual & recorded registration time
reg_time = pd.Timestamp(reg_time, unit='ms')
reg_min = pd.Timestamp(reg_min, unit='ms')
reg_max = pd.Timestamp(reg_max, unit='ms')
reg_min - reg_time

Timedelta('-40 days +14:21:38')

In [18]:
# Export the session info for evaluation
data_session.to_excel('summaries/sample_session.xlsx', index=False)

Notes
<br><i>Registration time is constant when a user is logged in, but it looks like there can be a difference of several days between submitting a registration and the value
held in the registration column. Hypothesis is that ts could be sourced from the users device while registration is a timestamp generated by the sparkify system. There could also be a batch process updating the registration field? Will tentatively try using the values in this column.</i>

<i>userAgent, userId, location, registration, gender, firstName, lastName can all be filled in based on the sessionId, should improve data availability</i>

### Data Cleaning

#### Remove userId 1261737 from full dataset

In [19]:
if use_full_dataset:
    data_raw = data_raw.withColumn(
        'userId',
        ssf.when(data_raw['userId'] == 1261737, None).otherwise(data_raw['userId'])
    )

#### Fill in the Gaps

In [20]:
# Get non-null values for each session ID
data_gapfill = data_raw.groupby('sessionId').agg(
    # pylint: disable=no-member
    ssf.max('userAgent').alias('userAgent'),
    ssf.max('userId').alias('userId'),
    ssf.max('location').alias('location'),
    ssf.max('registration').alias('registration'),
    ssf.max('gender').alias('gender'),
    ssf.max('firstName').alias('firstName'),
    ssf.max('lastName').alias('lastName')
).persist()

In [21]:
# Drop these columns from the original dataset
data_cleaned = data_raw.drop(
    'userAgent', 'userId', 'location', 'registration', 'gender', 'firstName', 'lastName')

In [22]:
# Merge back to ensure fields are fully populated
data_cleaned = data_cleaned.join(data_gapfill, on='sessionId', how='inner')
data_cleaned.where(data_raw['sessionId'] == session_id).toPandas().sample(5)

Unnamed: 0,sessionId,artist,auth,itemInSession,length,level,method,page,song,status,ts,userAgent,userId,location,registration,gender,firstName,lastName
143,15002,Judas Priest,Logged In,128,210.12853,paid,PUT,NextSong,Living After Midnight,200,1541079909000,Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) G...,1745135,"Riverside-San Bernardino-Ontario, CA",1538391976000,M,Kyler,Wall
87,15002,I'm From Barcelona,Logged In,72,175.41179,paid,PUT,NextSong,Rec & Play,200,1541069422000,Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) G...,1745135,"Riverside-San Bernardino-Ontario, CA",1538391976000,M,Kyler,Wall
41,15002,Justin Bieber,Logged In,18,191.55546,paid,PUT,NextSong,Love Me,200,1541061240000,Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) G...,1745135,"Riverside-San Bernardino-Ontario, CA",1538391976000,M,Kyler,Wall
232,15002,System of a Down,Logged In,49,164.15302,paid,PUT,NextSong,She's Like Heroin,200,1539907218000,Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) G...,1745135,"Riverside-San Bernardino-Ontario, CA",1538391976000,M,Kyler,Wall
302,15002,Muse,Logged In,42,380.86485,paid,PUT,NextSong,Space Dementia,200,1539976852000,Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) G...,1745135,"Riverside-San Bernardino-Ontario, CA",1538391976000,M,Kyler,Wall


#### Process Text Data

In [23]:
# Get platform from user agent string
platform_getter = ssf.regexp_extract(
    data_cleaned['userAgent'],
    r'[\w\/\.]+ \(([\w\s\.]+);.*\)',
    1
)

data_cleaned = data_cleaned.withColumn('platform', platform_getter)

In [24]:
# Bring through the state code as a field
def extract_state(location):
    '''Extract the state code from the location. If multiple state codes are
    given then just take the last one. This seems to happen when the location
    falls within a metropolitan area which crosses state lines.
    
    Sample locations:
    Philadelphia-Camden-Wilmington, PA-NJ-DE-MD
    Dallas-Fort Worth-Arlington, TX
    '''

    if not isinstance(location, str):
        return None
    state = location.split(',')[-1]
    state = state.split('-')[-1].strip()
    return state

state_getter = ssf.udf(extract_state, sst.StringType())
data_cleaned = data_cleaned.withColumn('state', state_getter('location'))

In [25]:
# Tidy up song names

def extract_song(song):
    '''Standardize song title cases and remove any extra tags.
    
    Example:
    Song Title [feat. artist] (Album Version) -->  SONG TITLE'''

    if not isinstance(song, str):
        return None

    # Remove any bracketed tags
    song = re.sub(r' \[.+\]', '', song)
    song = re.sub(r' \(.+\)', '', song)

    # Remove any non-standard characters
    song = re.sub(r'[^\w\s]+', '', song)

    # Fix any duplicated spaces
    song = re.sub(r'\s\s+', ' ', song)

    # Standardize case, remove trailing whitespace
    song = song.strip().upper()
    return song

song_getter = ssf.udf(extract_song, sst.StringType())
data_cleaned = data_cleaned.withColumn('songCleaned', song_getter('song'))

## EDA
Please note that between this section and "Feature Engineering", the majority of tables defined will not be reused. They are used for the sole purpose of generating summary charts/tables.

#### Define Churn

In [26]:
# Indicate the exact moment a customer churns
data_cleaned = data_cleaned.withColumn(
    'churnFlag',
    ssf.when(data_cleaned['page'] == 'Cancellation Confirmation', 1).otherwise(0)
).withColumn(
    'downgradeFlag',
    ssf.when(data_cleaned['page'] == 'Submit Downgrade', 1).otherwise(0)
)

In [27]:
# Flag all records corresponding to a churned customer
churned = data_cleaned.where(data_cleaned['churnFlag'] == 1).select('userId').distinct()
churned = churned.toPandas()['userId'].tolist() # Potentially risky for truly "big" data
data_cleaned = data_cleaned.withColumn(
    'userChurnFlag',
    ssf.when(data_cleaned['userId'].isin(*churned), 1).otherwise(0))

data_cleaned = data_cleaned.persist()

#### Investigate data

In [28]:
# Simple Counts
user_stats = data_cleaned.groupby('userChurnFlag', 'userId').agg(
    #pylint: disable=no-member
    ssf.countDistinct('artist').alias('noArtists'),
    ssf.countDistinct('songCleaned').alias('noSongs'),
    ssf.count('songCleaned').alias('noPlays'),
    ssf.max('gender').alias('gender'),
    ssf.max('state').alias('state'),
    ssf.max('registration').alias('registration'),
    ssf.mean('length').alias('meanSongLength')
)
user_stats = user_stats.toPandas()
user_stats.loc[:, 'registration'] = user_stats['registration'].map(
    lambda x: pd.Timestamp(x, unit='ms'))

In [29]:
# Items in Session
item_stats = data_cleaned.groupby('userChurnFlag', 'sessionId').agg(
    #pylint: disable=no-member
    ssf.max('itemInSession').alias('sessionLength'),
    ssf.max('platform').alias('platform')
)
item_stats = item_stats.toPandas()

In [30]:
# Page visits
page_stats = data_cleaned.groupby('userChurnFlag', 'page').count()
page_stats = page_stats.toPandas()

In [31]:
# Status codes
status_stats = data_cleaned.groupby('userChurnFlag', 'status').count()
status_stats = status_stats.toPandas()

#### Generate Plots

In [34]:
# Simple counts (continuous)
id_cols = {'userChurnFlag', 'userId'}
cat_cols = {'gender', 'state'}
for user_col in [x for x in user_stats.columns if x not in id_cols.union(cat_cols)]:
    fig = px.histogram(
        user_stats,
        x=user_col,
        color='userChurnFlag',
        barmode='overlay',
        histnorm='percent',
        labels={0: 'No Churn', 1: 'Churn'},
        nbins=50
        )
    fig.write_html(f'figures/{user_col}StatsPerc.html')

    fig = px.histogram(
        user_stats,
        x=user_col,
        color='userChurnFlag',
        barmode='overlay',
        labels={0: 'No Churn', 1: 'Churn'},
        nbins=50)
    fig.write_html(f'figures/{user_col}StatsAbs.html')

In [35]:
# Simple counts (discrete)
group_counts = user_stats.groupby('userChurnFlag').agg(
    totalUsers = ('userId', 'nunique')
).reset_index()
for user_col in cat_cols:
    plot_df = user_stats.groupby(['userChurnFlag', user_col]).agg(
        noUsers = ('userId', 'nunique')
    ).reset_index()
    plot_df = plot_df.merge(group_counts, on='userChurnFlag', how='inner')
    plot_df.loc[:, 'percUsers'] = plot_df['noUsers'] / plot_df['totalUsers']
    plot_df = plot_df.sort_values(by='noUsers', ascending=False)
    plot_df.loc[:, 'userChurnFlag'] = plot_df['userChurnFlag'].astype(bool)

    fig = px.bar(
        plot_df,
        x=user_col,
        y='percUsers',
        color='userChurnFlag',
        barmode='group',
        labels={True: 'Churn', False: 'No Churn'}
    )
    fig.write_html(f'figures/{user_col}StatsPerc.html')

    fig = px.bar(
        plot_df,
        x=user_col,
        y='noUsers',
        color='userChurnFlag',
        barmode='group',
        labels={True: 'Churn', False: 'No Churn'}
    )
    fig.write_html(f'figures/{user_col}StatsAbs.html')

In [36]:
# Session length
fig = px.histogram(
    item_stats,
    x='sessionLength',
    color='userChurnFlag',
    barmode='overlay',
    histnorm='percent',
    labels={0: 'No Churn', 1: 'Churn'},
    nbins=50)
fig.write_html('figures/sessionLengthStatsPerc.html')

fig = px.histogram(
    item_stats,
    x='sessionLength',
    color='userChurnFlag',
    barmode='overlay',
    labels={0: 'No Churn', 1: 'Churn'},
    nbins=50)
fig.write_html('figures/sessionLengthStatsAbs.html')

In [37]:
# Platform
fig = px.histogram(
    item_stats,
    x='platform',
    color='userChurnFlag',
    barmode='group',
    histnorm='percent',
    labels={0: 'No Churn', 1: 'Churn'}
)
fig.write_html('figures/platformStatsPerc.html')

fig = px.histogram(
    item_stats,
    x='platform',
    color='userChurnFlag',
    barmode='group',
    labels={0: 'No Churn', 1: 'Churn'}
)
fig.write_html('figures/platformStatsAbs.html')

In [38]:
# Page visits
page_stats.loc[:, 'userChurnFlag'] = page_stats['userChurnFlag'].astype(bool)
group_sums = page_stats.groupby('userChurnFlag').agg(
    totalVisits = ('count', 'sum')
).reset_index()
page_stats = page_stats.merge(group_sums, on='userChurnFlag', how='inner')
page_stats.loc[:, 'percentage'] = page_stats['count'] / page_stats['totalVisits']
page_stats = page_stats.sort_values(by='count', ascending=False)

fig = px.bar(
    page_stats,
    x='page',
    y='count',
    color='userChurnFlag',
    barmode='group',
)
fig.write_html('figures/pageStatsAbs.html')

fig = px.bar(
    page_stats,
    x='page',
    y='percentage',
    color='userChurnFlag',
    barmode='group',
)
fig.write_html('figures/pageStatsPerc.html')

In [39]:
# Status codes
status_stats.loc[:, 'userChurnFlag'] = status_stats['userChurnFlag'].astype(bool)
group_sums = status_stats.groupby('userChurnFlag').agg(
    totalResponses = ('count', 'sum')
).reset_index()
status_stats = status_stats.merge(group_sums, on='userChurnFlag', how='inner')
status_stats.loc[:, 'percentage'] = status_stats['count'] / status_stats['totalResponses']
status_stats = status_stats.sort_values(by='count', ascending=False)

fig = px.bar(
    status_stats,
    x='status',
    y='count',
    color='userChurnFlag',
    barmode='group',
)
fig.write_html('figures/statusStatsAbs.html')

fig = px.bar(
    status_stats,
    x='status',
    y='percentage',
    color='userChurnFlag',
    barmode='group',
)
fig.write_html('figures/statusStatsPerc.html')

## Feature Engineering

In [40]:
# Trigger evaluation of the datset up to this point
data_cleaned = data_cleaned.orderBy(
    'userID', 'ts', ascending=True
)

#### Static variables (no time dependency)

In [41]:
# Calculate these measures using the full range of data
static_vars = data_cleaned.groupby('userId').agg(
    #pylint: disable=no-member
    ssf.max('gender').alias('gender'),
    ssf.max('ts').alias('lastTs'),
    ssf.min('registration').alias('registration'),
    ssf.max('state').alias('state'),
    ssf.max('churnFlag').alias('userChurnFlag')
)

In [42]:
# Calculate the current age of each account
static_vars = static_vars.withColumn(
    'accountAge',
    static_vars['lastTs']-static_vars['registration'])

static_vars = static_vars.drop('lastTs', 'registration')

#### Dynamic Variables (to be evaluated over multiple time windows)

In [43]:
def get_modal_system(data_in, name):
    '''For the provided subset of data, work out the most frequently used
    system for each userId. Rename this derived "platform" column, appending it
    with the provided value for "name".'''

    sys_counts = data_in.groupby('userId', 'platform').agg(
        ssf.countDistinct('sessionId').alias('noSessions'))

    # No built-in mode function, need to manually rank using window functions
    # Using row_number rather than rank to avoid ties
    usr_window = ssw.Window.partitionBy(
        sys_counts['userId']
    ).orderBy(
        sys_counts['noSessions'].desc()
    )
    sys_counts = sys_counts.withColumn(
        #pylint: disable=no-member
        'platformRank', ssf.row_number().over(usr_window)
    )

    # Take first ranked platform for each 
    sys_modal = sys_counts.where(
        sys_counts['platformRank'] == 1
    ).select(
        'userId', f'platform'
    )

    sys_modal = sys_modal.withColumnRenamed('platform', f'platform{name}')

    return sys_modal

In [44]:
def get_simple_aggregates(data_in, name):
    '''For the provided subset of data, derive all features which can be calculated
    using the built-in pyspark aggregation functions. Append all derived field names
    with the provided value for "name".'''
    
    # Integer flag for http errors returned
    flagged = data_in.withColumn(
        'httpError',
        ssf.when(data_in['status']==404, 1).otherwise(0)
    )
    
    # Perform various aggregations at a userId level
    dynamic_vars = flagged.groupby(
        'userId'
    ).agg(
        #pylint: disable=no-member
        ssf.count('platform').alias(f'noPlatforms{name}'),
        ssf.count('userAgent').alias(f'noSystems{name}'),
        ssf.sum('httpError').alias(f'httpErrors{name}'),
        ssf.countDistinct('artist').alias(f'noArtists{name}'),
        ssf.countDistinct('songCleaned').alias(f'noSongs{name}'),
        # Count excludes nulls, so this should be equivalent to counting
        # number of nextSong pages
        ssf.count('songCleaned').alias(f'noPlays{name}'),
        ssf.first('level').alias(f'levelStart{name}'),
        ssf.last('level').alias(f'levelEnd{name}')
    )

    return dynamic_vars

In [45]:
def get_session_stats(data_in, name):
    '''For the provided subset of data, derive some more complex features relating
    to session/song length. Append all derived field names with the provided value
    for "name".'''

    # Limit to song plays
    session_vars = data_in.filter(data_in['page'] == 'NextSong')

    # For each song, bring through the start of the next song
    session_window = ssw.Window.partitionBy(
        session_vars['userId'], session_vars['sessionId']
    ).orderBy(
        'ts'
    )
    session_vars = session_vars.withColumn(
        'nextPlayStart',
        ssf.lead(session_vars['ts']).over(session_window)
    )

    # Use this data to calculate the play time for each song
    session_vars = session_vars.withColumn(
        'playTime',
        session_vars['nextPlayStart'] - session_vars['ts']
    )

    # Get total play time for each session
    session_vars = session_vars.groupby(
        'userId', 'sessionId'
    ).agg(
        #pylint: disable=no-member
        ssf.sum('playTime').alias('sessionLength')
    )

    # Get play time stats for each user
    session_vars = session_vars.groupby(
        'userId'
    ).agg(
        #pylint: disable=no-member
        ssf.mean('sessionLength').alias(f'lengthMean{name}'),
        ssf.stddev('sessionLength').alias(f'lengthStd{name}'),
        ssf.sum('sessionLength').alias(f'lengthSum{name}')
    )

    return session_vars

In [46]:
def get_popularity_scores(data_in, name):
    '''Attempts to create a metric showing how popular the music each user listens
    to is. This metric is calculated for the provided subset of data, and the
    generated popularityScore field is appended with the provided value of "name".'''
    
    # Select song plays only
    song_data = data_in.filter(data_in['page'] == 'NextSong')

    # For each song, calculate the total number of plays
    song_totals = song_data.groupby('songCleaned').agg(
        #pylint: disable=no-member
        ssf.count('ts').alias('songTotal')
    )
    
    # Calculate the number of plays across all songs
    overall_total = song_totals.agg(
        #pylint: disable=no-member
        ssf.sum('songTotal').alias('overallTotal')
    ).collect()[0]['overallTotal']

    # For each song, how much does it contribute towards the total play count?
    song_totals = song_totals.withColumn(
        'overallPerc',
        song_totals['songTotal'] / overall_total
    ).select(
        'songCleaned', 'overallPerc'
    )

    # Restrict input dataset to only relevant fields
    popularity = song_data.select('userId', 'songCleaned')

    # Make a clone of popularity, bug in pyspark causes errors if this isn't done
    # https://stackoverflow.com/questions/45713290/how-to-resolve-the-analysisexception-resolved-attributes-in-spark
    popularity = spark.createDataFrame(popularity.rdd, popularity.schema)

    # Join on popularity scores for each song listened to
    popularity = popularity.join(
        song_totals,
        popularity['songCleaned']==song_totals['songCleaned'],
        how='left')
    
    # For each user, how 'popular' is the music they listen to?
    # As popularity was not created using SELECT DISTINCT this metric
    # is weighted according to the number of plays for each song.
    popularity = popularity.groupby('userId').agg(
        #pylint: disable=no-member
        ssf.sum('overallPerc').alias(f'popularityScore{name}')
    )

    return popularity

In [47]:
def get_page_clicks(data_in, name):
    '''Calculate the number of times each user has clicked through to each page in
    the provided subset of data. All generated field names are appended with the
    provided value of "name".
    
    This function is not currently used, as I decided to avoid the use of the
    pivot operation while debugging memory issues.'''

    def format_page(page):
        '''Remove spaces, append value of name to each page to ensure
        it appears in column headers after the pivot operation.'''

        if not isinstance(page, str):
            return None
        
        page = page.replace(' ', '')
        page = f'{page}{name}'
        return page
    
    page_getter = ssf.udf(format_page, sst.StringType())

    # Process all page names
    page_clicks = data_in.withColumn('page', page_getter('page'))

    # Perform the pivot/aggregation
    page_clicks = page_clicks.groupby('userId').pivot('page').count()

    # List of all pages which should be accounted for
    included_pages = [
        'About', 'Add Friend', 'Add to Playlist', 'Downgrade',
        'Error', 'Help', 'Home', 'Login', 'Logout', 'NextSong', 'Register',
        'Roll Advert', 'Save Settings', 'Settings', 'Submit Downgrade', 'Submit Registration',
        'Submit Upgrade', 'Thumbs Down', 'Thumbs Up', 'Upgrade'
    ]

    # Generate final list of field names
    included_pages = ['userId'] + [format_page(x) for x in included_pages]

    # And which ones didn't
    missing_pages = [x for x in included_pages if x not in page_clicks.columns]

    # Add a column of 0s for any missing pages
    for missing_page in missing_pages:
        page_clicks = page_clicks.withColumn(
            #pylint: disable=no-member
            missing_page,
            ssf.lit(0)
        )

    # Ensure the output always has consistent headers
    page_clicks = page_clicks.select(*included_pages)

    return page_clicks

In [48]:
def get_page_clicks_alt(data_in, name):
    '''Calculate the number of times each user has clicked through to each page in
    the provided subset of data. All generated field names are appended with the
    provided value of "name".
    
    Alternate form of get_page_clicks, avoids use of the pivot function to
    (hopefully) allow the code to run when using the full dataset. In theory the
    output of this function should be the same as the original get_page_clicks.'''

    # List of all pages which should be accounted for
    # - Note that cancellation submission is not included, as this would result
    #   in a proxy column that contains the target label
    included_pages = [
        'About', 'Add Friend', 'Add to Playlist', 'Downgrade',
        'Error', 'Help', 'Home', 'Login', 'Logout', 'NextSong', 'Register',
        'Roll Advert', 'Save Settings', 'Settings', 'Submit Downgrade', 'Submit Registration',
        'Submit Upgrade', 'Thumbs Down', 'Thumbs Up', 'Upgrade'
    ]

    def format_page(page):
        '''Remove spaces, append value of name to each page to ensure
        it appears in column headers after the pivot operation.'''
        if not isinstance(page, str):
            return None
        
        page = page.replace(' ', '')
        page = f'{page}{name}'
        return page
    
    page_getter = ssf.udf(format_page, sst.StringType())

    # Process all page names in the data
    page_clicks = data_in.withColumn('page', page_getter('page'))
    
    # Apply the same transformation to the list of expected page names
    included_pages = [format_page(x) for x in included_pages]

    # Select only data pertaining to pages visited by a user
    page_clicks = page_clicks.select('userId', 'page').dropna(how='any')

    # Add new field for count of each page
    for included_page in included_pages:
        page_clicks = page_clicks.withColumn(
            included_page,
            ssf.when(page_clicks['page'] == included_page, 1).otherwise(0)
        )
    
    # Set up list of aggregations required to calculate the number
    # of visits to each page
    aggregations = [
        #pylint: disable=no-member
        ssf.sum(col).alias(col)
        for col in included_pages
    ]

    # Apply this list of aggregations for each userId
    page_clicks = page_clicks.groupby('userId').agg(*aggregations)

    return page_clicks

#### Generate aggregate statistics

In [49]:
# Get the max timestamp
limits = data_cleaned.agg(
    #pylint: disable=no-member
    ssf.max('ts').alias('maxTs'),
    ssf.min('ts').alias('minTs')
).collect()[0]
max_ts = limits['maxTs']
min_ts = limits['minTs']

In [50]:
# Factors to convert to/from timestamps and days/weeks/months
day_delta = 24 * 60 * 60 * 1000
week_delta = 7 * day_delta
month_delta = 31 * day_delta

In [51]:
# How many months of data do we have?
(max_ts - min_ts) / month_delta

1.9677423088410992

In [52]:
# Create a dictionary of date ranges over which all 'dynamic'
# measures will be evaluated. The keys are simply aliases which
# will be passed to the "name" parameter of the functions defined above.
# This ensures measures have distinct field names between iterations.
start_times = {
    'Week': max_ts - week_delta,
    'Month': max_ts - month_delta
}

# More data is available in the full dataset
if use_full_dataset:
    start_times['TwoMonth'] = max_ts - 2*month_delta

In [53]:
# Calculate these 'dynamic' measures over each of the specified date ranges

# Container for the output from each iteration
outputs = None

for name, start_ts in start_times.items():

    # Select only data from after the start time
    data_subset = data_cleaned.where(
        data_cleaned['ts'] >= start_ts
    ).persist()

    # Calculate all measures using functions defined above
    modal_system = get_modal_system(data_subset, name)
    simple_aggregates = get_simple_aggregates(data_subset, name)
    session_stats = get_session_stats(data_subset, name)
    popularity_scores = get_popularity_scores(data_subset, name)
    page_clicks = get_page_clicks_alt(data_subset, name)

    # Join all of the generated datasets together
    merged = modal_system.join(
        simple_aggregates,
        on='userId',
        how='full_outer'
    ).join(
        session_stats,
        on='userId',
        how='full_outer'
    ).join(
        popularity_scores,
        on='userId',
        how='full_outer'
    ).join(
        page_clicks,
        on='userId',
        how='full_outer'
    ).persist()

    # Add the joined dataset to the final output
    if outputs is None:
        outputs = merged
    else:
        outputs = outputs.join(merged, on='userId', how='full_outer')

In [54]:
# Finally, join the static measures back in
merged = static_vars.join(
    outputs,
    on='userId',
    how='full_outer'
)

# # Fill in empty platform data with 'unknown'
# platform_cols = [x for x in outputs.columns if 'platform' in x]
# merged = merged.fillna('unknown', subset=platform_cols)

# # Fill in empty level data with 'free'
# level_cols = [x for x in outputs.columns if 'level' in x]
# merged = merged.fillna('free', subset=level_cols)

# # Fill in any other empty values with 0
# merged = merged.fillna(0)

#### One-hot Encoding

In [55]:
def fill_empty_string(string_in, fill_value='unknown'):
    '''A simple function which fills in missing values in a categorical field.'''

    if not isinstance(string_in, str):
        return fill_value
    elif not string_in:
        return fill_value
    else:
        return string_in

na_handler = ssf.udf(fill_empty_string, sst.StringType())

In [56]:
# Replace categorical values with numerical IDs for each category

# Get a list of categorical columns in the dataset
# - Needed as some of these will appear in multiple columns (one for each named period)
cat_cols = ['gender', 'state', 'platform', 'level']
cat_cols = [x for x in merged.columns if any((y for y in cat_cols if y in x))]

indexers = {}
for cat_col in cat_cols:
    # Fill in any null values
    merged = merged.withColumn(cat_col, na_handler(cat_col))
    
    # Train a string indexer which generates mappings for each category to
    # a corresponding numerical ID
    indexer = smf.StringIndexer(
        inputCol=cat_col,
        outputCol=f"{cat_col}Inx"
    )
    indexer = indexer.fit(merged)
    
    # Apply the indexer to the dataset, drop the original column
    merged = indexer.transform(merged)
    merged = merged.drop(cat_col).withColumnRenamed(f"{cat_col}Inx", cat_col)
    
    # Save the indexer object for future reference
    indexers[cat_col] = indexer

In [57]:
# Fill in any remaining null values, as they must correspond to numerical features
merged = merged.fillna(0)

In [58]:
# Perform one-hot encoding on all categorical features
encoder = smf.OneHotEncoder(
    inputCols=cat_cols,
    outputCols=[f"{x}Vec" for x in cat_cols]
)
encoder = encoder.fit(merged)
encoded = encoder.transform(merged)
encoded = encoded.drop(*cat_cols)

encoded = encoded.persist()

In [59]:
# Use VectorAssembler to combine all features into a single vector
feature_cols = [x for x in encoded.columns if x not in {'userId', 'userChurnFlag'}]
assembler = smf.VectorAssembler(
    inputCols=feature_cols,
    outputCol='features')
encoded = assembler.transform(encoded)
encoded = encoded.drop(*feature_cols)
encoded = encoded.withColumnRenamed('userChurnFlag', 'label')
encoded = encoded.persist()

# Check the format of the final dataset
encoded_sample = encoded.limit(5).toPandas()
encoded_sample

Unnamed: 0,userId,label,features
0,1000280,0,"(2963223000.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0..."
1,1002185,0,"(5680891000.0, 361.0, 361.0, 1.0, 266.0, 290.0..."
2,1030587,0,"(11370026000.0, 365.0, 365.0, 0.0, 270.0, 292...."
3,1033297,0,"(10034889000.0, 204.0, 204.0, 0.0, 146.0, 157...."
4,1057724,0,"(8298249000.0, 613.0, 613.0, 0.0, 417.0, 479.0..."


## Modeling

#### Set up Pipeline

In [60]:
# Split out validation dataset
train, val = encoded.randomSplit([3.0, 1.0], seed=42)

In [61]:
# Set up scaler for numerical features
scaler = smf.StandardScaler(
    withStd=True,
    withMean=False,
    inputCol='features',
    outputCol='scaledFeatures')

In [62]:
# Use PCA to reduce dimensionality of scaled vectors
reducer = smf.PCA(
    k=10,
    inputCol=scaler.getOutputCol(),
    outputCol='selectedFeatures')

In [63]:
# Use a classifier to generate the final predictions
classifier = smc.GBTClassifier(
    labelCol='label',
    featuresCol=reducer.getOutputCol(),
    predictionCol='predictedLabel'
)

In [64]:
# Combine all steps in a pipeline
pipeline = sm.Pipeline(
    stages=[scaler, reducer, classifier]
)

In [65]:
# Create an evaluator which will quantify model performance
eval_f1 = sme.MulticlassClassificationEvaluator(
    labelCol='label',
    predictionCol='predictedLabel',
    metricName='f1'
)

In [66]:
# Set up a parameter grid for cross validation
param_grid = smt.ParamGridBuilder().addGrid(
    # Number of features to generate
    reducer.k, [25, 50, 75, 100]
).addGrid(
    classifier.maxDepth, [3, 5, 10]
).addGrid(
    classifier.subsamplingRate, [0.1, 0.2, 0.3]
).build()

In [67]:
# Bring everything together
validator = smt.CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=param_grid,
    evaluator=eval_f1,
    numFolds=3
)

#### Fit the Model to the Data

In [68]:
# Fit to data
model = validator.fit(train)

In [70]:
# Save the trained model
model.save('final_model.bin')

In [71]:
train_predictions = model.transform(train)
val_predictions = model.transform(val)

#### Evaluate Model Performance

In [72]:
# Get the selected model parameters
model.getEstimatorParamMaps()[np.argmax(model.avgMetrics)]

{Param(parent='PCA_ffcc51f04dd9', name='k', doc='the number of principal components'): 25,
 Param(parent='GBTClassifier_939da222e149', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.'): 5,
 Param(parent='GBTClassifier_939da222e149', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].'): 0.3}

In [73]:
train_f1 = eval_f1.evaluate(train_predictions)
val_f1 = eval_f1.evaluate(val_predictions)
print(f"Train F1: {train_f1:,.2f}\nValidation F1: {val_f1:,.2f}")

Train F1: 0.82
Validation F1: 0.80


In [98]:
eval_roc = sme.BinaryClassificationEvaluator(
    labelCol='label',
    rawPredictionCol='probability',
    metricName='areaUnderROC'
)

train_auc = eval_roc.evaluate(train_predictions)
val_auc = eval_roc.evaluate(val_predictions)
print(f"Train AUC: {train_auc:,.2f}\nValidation AUC: {val_auc:,.2f}")

Train AUC: 0.85
Validation AUC: 0.81


In [75]:
eval_accuracy = sme.MulticlassClassificationEvaluator(
    labelCol='label',
    predictionCol='predictedLabel',
    metricName='accuracy'
)

train_acc = eval_accuracy.evaluate(train_predictions)
val_acc = eval_accuracy.evaluate(val_predictions)
print(f"Train Accuracy: {train_acc:,.2f}\nValidation Accuracy: {val_acc:,.2f}")

Train Accuracy: 0.84
Validation Accuracy: 0.81


In [76]:
eval_precision = sme.MulticlassClassificationEvaluator(
    labelCol='label',
    predictionCol='predictedLabel',
    metricName='weightedPrecision'
)

train_prec = eval_precision.evaluate(train_predictions)
val_prec = eval_precision.evaluate(val_predictions)
print(f"Train Precision: {train_prec:,.2f}\nValidation Precision: {val_prec:,.2f}")

Train Precision: 0.82
Validation Precision: 0.80


In [77]:
eval_recall = sme.MulticlassClassificationEvaluator(
    labelCol='label',
    predictionCol='predictedLabel',
    metricName='weightedRecall'
)

train_rec = eval_recall.evaluate(train_predictions)
val_rec = eval_recall.evaluate(val_predictions)
print(f"Train Recall: {train_rec:,.2f}\nValidation Recall: {val_rec:,.2f}")

Train Recall: 0.84
Validation Recall: 0.81


In [78]:
# Generate confusion matrices for each dataset
trn_res_df = train_predictions.select('label', 'predictedLabel').toPandas()
trn_cmat = confusion_matrix(
    y_true=trn_res_df['label'],
    y_pred=trn_res_df['predictedLabel'])


val_res_df = val_predictions.select('label', 'predictedLabel').toPandas()
val_cmat = confusion_matrix(
    y_true=val_res_df['label'],
    y_pred=val_res_df['predictedLabel'])

In [86]:
# Plot the training confusion matrix
fig = ff.create_annotated_heatmap(
    trn_cmat,
    colorscale='haline',
    x=['No Churn', 'Churn'],
    y=['No Churn', 'Churn'])
fig.update_xaxes(side='bottom', title='Predicted Label')
fig.update_yaxes(title='Actual Label')
fig.update_layout(title='Confusion Matrix - Training Dataset')
fig.write_html('figures/cmat_train.html')

In [87]:
# Plot the validation confusion matrix
fig = ff.create_annotated_heatmap(
    val_cmat,
    colorscale='haline',
    x=['No Churn', 'Churn'],
    y=['No Churn', 'Churn'])
fig.update_xaxes(side='bottom', title='Predicted Label')
fig.update_yaxes(title='Actual Label')
fig.update_layout(title='Confusion Matrix - Validation Dataset')
fig.write_html('figures/cmat_val.html')

#### Can we do better by working out the optimal cutoff point

In [141]:
from sklearn.metrics import roc_curve, precision_score, accuracy_score, recall_score
import plotly.graph_objects as go

In [102]:
trn_pred_df = train_predictions.select('userId', 'probability', 'label', 'predictedLabel').toPandas()
val_pred_df = val_predictions.select('userId', 'probability', 'label', 'predictedLabel').toPandas()

trn_pred_df.to_csv('data/train_predictions.csv', index=False)
val_pred_df.to_csv('data/val_predictions.csv', index=False)

In [103]:
val_pred_df.loc[:, 'churnProba'] = val_pred_df['probability'].map(lambda x: x[-1])
trn_pred_df.loc[:, 'churnProba'] = trn_pred_df['probability'].map(lambda x: x[-1])

In [112]:
trn_fpr, trn_tpr, trn_thresholds = roc_curve(trn_pred_df['label'], trn_pred_df['churnProba'])
val_fpr, val_tpr, val_thresholds = roc_curve(val_pred_df['label'], val_pred_df['churnProba'])

In [144]:
# Plot the ROC curve
trn_trace = go.Scatter(
    x=trn_fpr,
    y=trn_tpr,
    text=trn_thresholds,
    name='Training',
    mode='lines',
#     fill='tozeroy'
)

val_trace = go.Scatter(
    x=val_fpr,
    y=val_tpr,
    text=val_thresholds,
    name='Validation',
    mode='lines',
#     fill='tozeroy'
)

ns_trace = go.Scatter(
    x=[0,1],
    y=[0,1],
    name='No Skill',
    mode='lines',
#     fill='tozeroy'
)

layout = go.Layout(
    title='ROC Curve for Final Model',
    xaxis_range=[0,1],
    yaxis_range=[0,1],
    xaxis_title='False Positive Rate',
    yaxis_title='True Positive Rate',
    xaxis_tickformat='%',
    yaxis_tickformat='%',
    yaxis_scaleanchor='x',
    yaxis_scaleratio=1
)

figure = go.Figure(
    data=[trn_trace, val_trace, ns_trace],
    layout=layout
)

figure.write_html('figures/roc_curve.html')

In [127]:
# What does the confusion matrix look like if we set the threshold to 16% (for a tpr of 80%? on the validation dataset)
trn_pred_df.loc[:, 'predictedLabelNew'] = trn_pred_df['churnProba'].map(lambda x: 1 if x >= 0.16 else 0)
val_pred_df.loc[:, 'predictedLabelNew'] = val_pred_df['churnProba'].map(lambda x: 1 if x >= 0.16 else 0)

In [128]:
# Generate updated confusion matrices
trn_cmat = confusion_matrix(
    y_true=trn_pred_df['label'],
    y_pred=trn_pred_df['predictedLabelNew'])

val_cmat = confusion_matrix(
    y_true=val_pred_df['label'],
    y_pred=val_pred_df['predictedLabelNew'])

In [135]:
# Plot the training confusion matrix
fig = ff.create_annotated_heatmap(
    trn_cmat,
    colorscale='haline',
    x=['No Churn', 'Churn'],
    y=['No Churn', 'Churn'])
fig.update_xaxes(side='bottom', title='Predicted Label')
fig.update_yaxes(title='Actual Label')
fig.update_layout(title='Confusion Matrix - Training Dataset @16%')
fig.write_html('figures/cmat_train_tuned.html')

In [136]:
# Plot the validation confusion matrix
fig = ff.create_annotated_heatmap(
    val_cmat,
    colorscale='haline',
    x=['No Churn', 'Churn'],
    y=['No Churn', 'Churn'])
fig.update_xaxes(side='bottom', title='Predicted Label')
fig.update_yaxes(title='Actual Label')
fig.update_layout(title='Confusion Matrix - Validation Dataset @16%')
fig.write_html('figures/cmat_val_tuned.html')

#### Sense-check, how separable are the two classes?

In [150]:
# Combine datasets & add a text representation of the class for readability
trn_pred_df.loc[:, 'dataset'] = 'Training',
val_pred_df.loc[:, 'dataset'] = 'Validation'
plot_df = pd.concat([trn_pred_df, val_pred_df], ignore_index=True, axis=0)
plot_df.loc[:, 'class'] = plot_df['label'].map(lambda x: 'Churn' if x else 'No Churn')
plot_df.head()

Unnamed: 0,userId,probability,label,predictedLabel,churnProba,dataset,predictedLabelNew,class
0,1000280,"[0.677025680300048, 0.32297431969995205]",0,0.0,0.322974,Training,1,No Churn
1,1002185,"[0.927641734964618, 0.07235826503538201]",0,0.0,0.072358,Training,0,No Churn
2,1033297,"[0.9448878101412581, 0.05511218985874189]",0,0.0,0.055112,Training,0,No Churn
3,1057724,"[0.9013045896844434, 0.09869541031555662]",0,0.0,0.098695,Training,0,No Churn
4,1059049,"[0.8884397826922436, 0.11156021730775645]",0,0.0,0.11156,Training,0,No Churn


In [153]:
# Create a distribution plot to visualize churnProba between classes
fig = px.histogram(
    plot_df,
    x='churnProba',
    color='class',
    histnorm='percent',
    marginal='box',
    title='Final Model - Class Separability',
    barmode='overlay'
)
fig.write_html('figures/class_separability.html')