In [1]:
# Run this cell to import pyspark and to define start_spark() and stop_spark()

import findspark

findspark.init()

import getpass
import pandas
import pyspark
import random
import re

from IPython.display import display, HTML
from pyspark import SparkContext
from pyspark.sql import SparkSession
import pyspark.sql.functions as F


# Constants used to interact with Azure Blob Storage using the hdfs command or Spark

global username

username = re.sub('@.*', '', getpass.getuser())

global azure_account_name
global azure_data_container_name
global azure_user_container_name
global azure_user_token

azure_account_name = "madsstorage002"
azure_data_container_name = "campus-data"
azure_user_container_name = "campus-user"
azure_user_token = r"sp=racwdl&st=2025-08-01T09:41:33Z&se=2026-12-30T16:56:33Z&spr=https&sv=2024-11-04&sr=c&sig=GzR1hq7EJ0lRHj92oDO1MBNjkc602nrpfB5H8Cl7FFY%3D"


# Functions used below

def dict_to_html(d):
    """Convert a Python dictionary into a two column table for display.
    """

    html = []

    html.append(f'<table width="100%" style="width:100%; font-family: monospace;">')
    for k, v in d.items():
        html.append(f'<tr><td style="text-align:left;">{k}</td><td>{v}</td></tr>')
    html.append(f'</table>')

    return ''.join(html)


def show_as_html(df, n=20):
    """Leverage existing pandas jupyter integration to show a spark dataframe as html.
    
    Args:
        n (int): number of rows to show (default: 20)
    """

    display(df.limit(n).toPandas())

    
def display_spark():
    """Display the status of the active Spark session if one is currently running.
    """
    
    if 'spark' in globals() and 'sc' in globals():

        name = sc.getConf().get("spark.app.name")

        html = [
            f'<p><b>Spark</b></p>',
            f'<p>The spark session is <b><span style="color:green">active</span></b>, look for <code>{name}</code> under the running applications section in the Spark UI.</p>',
            f'<ul>',
            f'<li><a href="http://localhost:{sc.uiWebUrl.split(":")[-1]}" target="_blank">Spark Application UI</a></li>',
            f'</ul>',
            f'<p><b>Config</b></p>',
            dict_to_html(dict(sc.getConf().getAll())),
            f'<p><b>Notes</b></p>',
            f'<ul>',
            f'<li>The spark session <code>spark</code> and spark context <code>sc</code> global variables have been defined by <code>start_spark()</code>.</li>',
            f'<li>Please run <code>stop_spark()</code> before closing the notebook or restarting the kernel or kill <code>{name}</code> by hand using the link in the Spark UI.</li>',
            f'</ul>',
        ]
        display(HTML(''.join(html)))
        
    else:
        
        html = [
            f'<p><b>Spark</b></p>',
            f'<p>The spark session is <b><span style="color:red">stopped</span></b>, confirm that <code>{username} (notebook)</code> is under the completed applications section in the Spark UI.</p>',
            f'<ul>',
            f'<li><a href="http://mathmadslinux2p.canterbury.ac.nz:8080/" target="_blank">Spark UI</a></li>',
            f'</ul>',
        ]
        display(HTML(''.join(html)))


# Functions to start and stop spark

def start_spark(executor_instances=2, executor_cores=1, worker_memory=1, master_memory=1):
    """Start a new Spark session and define globals for SparkSession (spark) and SparkContext (sc).
    
    Args:
        executor_instances (int): number of executors (default: 2)
        executor_cores (int): number of cores per executor (default: 1)
        worker_memory (float): worker memory (default: 1)
        master_memory (float): master memory (default: 1)
    """

    global spark
    global sc

    cores = executor_instances * executor_cores
    partitions = cores * 4
    port = 4000 + random.randint(1, 999)

    spark = (
        SparkSession.builder
        .config("spark.driver.extraJavaOptions", f"-Dderby.system.home=/tmp/{username}/spark/")
        .config("spark.dynamicAllocation.enabled", "false")
        .config("spark.executor.instances", str(executor_instances))
        .config("spark.executor.cores", str(executor_cores))
        .config("spark.cores.max", str(cores))
        .config("spark.driver.memory", f'{master_memory}g')
        .config("spark.executor.memory", f'{worker_memory}g')
        .config("spark.driver.maxResultSize", "0")
        .config("spark.sql.shuffle.partitions", str(partitions))
        .config("spark.kubernetes.container.image", "madsregistry001.azurecr.io/hadoop-spark:v3.3.5-openjdk-8")
        .config("spark.kubernetes.container.image.pullPolicy", "IfNotPresent")
        .config("spark.kubernetes.memoryOverheadFactor", "0.3")
        .config("spark.memory.fraction", "0.1")
        .config(f"fs.azure.sas.{azure_user_container_name}.{azure_account_name}.blob.core.windows.net",  azure_user_token)
        .config("spark.app.name", f"{username} (notebook)")
        .getOrCreate()
    )
    sc = SparkContext.getOrCreate()
    
    display_spark()

    
def stop_spark():
    """Stop the active Spark session and delete globals for SparkSession (spark) and SparkContext (sc).
    """

    global spark
    global sc

    if 'spark' in globals() and 'sc' in globals():

        spark.stop()

        del spark
        del sc

    display_spark()


# Make css changes to improve spark output readability

html = [
    '<style>',
    'pre { white-space: pre !important; }',
    'table.dataframe td { white-space: nowrap !important; }',
    'table.dataframe thead th:first-child, table.dataframe tbody th { display: none; }',
    '</style>',
]
display(HTML(''.join(html)))

In [2]:
# Run this cell to start a spark session in this notebook

start_spark(executor_instances=4, executor_cores=4, worker_memory=8, master_memory=8)

25/10/07 10:00:32 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


0,1
spark.dynamicAllocation.enabled,false
spark.fs.azure.sas.campus-user.madsstorage002.blob.core.windows.net,"""sp=racwdl&st=2025-08-01T09:41:33Z&se=2026-12-30T16:56:33Z&spr=https&sv=2024-11-04&sr=c&sig=GzR1hq7EJ0lRHj92oDO1MBNjkc602nrpfB5H8Cl7FFY%3D"""
spark.kubernetes.driver.pod.name,spark-master-driver
spark.executor.instances,4
spark.app.name,rsh224 (notebook)
spark.cores.max,16
spark.kubernetes.container.image.pullPolicy,IfNotPresent
spark.kubernetes.namespace,rsh224
spark.executor.cores,4
spark.kubernetes.executor.podNamePrefix,rsh224-notebook-d58c6299bb5378c3


In [None]:
# Write your imports here or insert cells below

from pyspark.sql import functions as F
from pyspark.sql.types import *
from functools import reduce
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import sys
import os

sys.path.append(os.path.abspath(".."))

from helpers import load_feature

In [5]:
directory_path = f'wasbs://{azure_data_container_name}@{azure_account_name}.blob.core.windows.net/msd'

In [6]:
features_path = f'{username}/msd/output/feature_genre'
input_path = f'wasbs://{azure_user_container_name}@{azure_account_name}.blob.core.windows.net/{features_path}'

In [7]:
features = spark.read.csv(
    input_path,
    inferSchema=True,
    header=True
)

                                                                                

In [8]:
features.show(20, False)

25/10/07 10:01:21 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


+------------------+------+-------------+-------+--------+--------+-------------+--------------+------------+-------------+-------------+------+-------------+------+-------+-------+------------+-------------+------------+-------------+-------------+------+-----+------+-------+-------+---+------+--------+--------+--------+-------+---------+---------+---------+-------+---+-------+------+------+-------+------+-------+-------+-------+--------+--------+----------+---------+---------+---------+--------+---------+---------+---------+---------+---------+---------+---------+---------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+---------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+---------+--------+--

In [10]:
features.describe().show()

25/10/07 10:41:34 WARN DAGScheduler: Broadcasting large task binary with size 1648.8 KiB
[Stage 8:>                                                          (0 + 1) / 1]

+-------+------------------+------------------+--------------------+-----------------+------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+------------------+--------------------+------------------+------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+------------------+-----------------+-------------------+--------------------+-------------------+------+-------------------+--------------------+--------------------+-------------------+-------------------+--------------------+--------------------+--------------------+--------------------+------+-------------------+-------------------+-------------------+--------------------+--------------------+--------------------+-------------------+-------------------+-------------------+--------------------+------------------+-------------------+-------------------+----

                                                                                

In [14]:
features_encoded = features.withColumn('is_electronic', F.when(F.col('genre') == 'Electronic', 1).otherwise(0))

In [15]:
features_encoded.show(20, False)

+------------------+------+-------------+-------+--------+--------+-------------+--------------+------------+-------------+-------------+------+-------------+------+-------+-------+------------+-------------+------------+-------------+-------------+------+-----+------+-------+-------+---+------+--------+--------+--------+-------+---------+---------+---------+-------+---+-------+------+------+-------+------+-------+-------+-------+--------+--------+----------+---------+---------+---------+--------+---------+---------+---------+---------+---------+---------+---------+---------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+---------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+---------+--------+--

In [20]:
class_counts = features_encoded.groupBy('is_electronic').count().orderBy('is_electronic')

class_counts.show()



+-------------+------+
|is_electronic| count|
+-------------+------+
|            0|379938|
|            1| 40662|
+-------------+------+



                                                                                

In [25]:
total = features.count()
class_counts = class_counts.withColumn('proportion', F.round(F.col('count') / total, 3))

class_counts.show()



+-------------+------+----------+
|is_electronic| count|proportion|
+-------------+------+----------+
|            0|379938|     0.903|
|            1| 40662|     0.097|
+-------------+------+----------+



                                                                                

In [37]:
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.sql.window import Windowis_electronic

# Scaling

In [29]:
feature_cols = [c for c in features_encoded.columns if c not in ('track_id', 'genre', 'is_electronic')]

assembler = VectorAssembler(inputCols=feature_cols, outputCol='features_raw')
assembled_features = assembler.transform(features_encoded)

In [32]:
scaler = StandardScaler(
    inputCol='features_raw',
    outputCol='features',
    withMean=True,
    withStd=True
)

scaler_model = scaler.fit(assembled_features)
scaled_df = scaler_model.transform(assembled_features)

                                                                                

In [35]:
scaled_df = scaled_df.select('track_id', 'features', 'is_electronic')

In [36]:
scaled_df.show(20, False)

+------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Stratified Sampling

In [None]:
train_fraction = 0.8

window_spec = Window.partitionBy('is_electronic').orderBy(F.rand())

df_ranked = scaled_df.withColumn('row_number', F.row_number().over(window_spec))

class_counts = df_ranked.groupBy('is_electronic').agg(F.count('*').alias('n')).collect()

cutoffs = {r['is_electronic']: int(r['n'] * train_fraction) for r in class_counts}

df_flagged = df_ranked.withColumn(
    'is_train',
    F.when(
        ((F.col("is_electronic") == 1) & (F.col("row_number") <= cutoffs[1])) |
        ((F.col("is_electronic") == 0) & (F.col("row_number") <= cutoffs[0])),
        True
    ).otherwise(False)
)

train_df = df_flagged.filter('is_train = True').drop('row_number', 'is_train')
test_df = df_flagged.filter('is_train = False').drop('row_number', 'is_train')

In [56]:
train_df.groupBy('is_electronic').count().show()



+-------------+------+
|is_electronic| count|
+-------------+------+
|            0|303950|
|            1| 32529|
+-------------+------+



                                                                                

In [57]:
test_df.groupBy('is_electronic').count().show()



+-------------+-----+
|is_electronic|count|
+-------------+-----+
|            0|75988|
|            1| 8133|
+-------------+-----+



                                                                                

# Model Training and evaluation

In [59]:
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml import Pipeline

In [61]:
lr = LogisticRegression(featuresCol='features', labelCol='is_electronic')
rf = RandomForestClassifier(featuresCol='features', labelCol='is_electronic', numTrees = 50, maxDepth=10)
gbt = GBTClassifier(featuresCol='features', labelCol='is_electronic', maxIter=50, maxDepth=5)

models = {"Logistic Regression": lr, "Random Forest": rf, "GBT": gbt}

In [67]:
def confusion_metrics(predictions):
    tp = predictions.filter("prediction = 1 AND is_electronic = 1").count()
    tn = predictions.filter("prediction = 0 AND is_electronic = 0").count()
    fp = predictions.filter("prediction = 1 AND is_electronic = 0").count()
    fn = predictions.filter("prediction = 0 AND is_electronic = 1").count()
    
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    accuracy = (tp + tn) / (tp + tn + fp + fn)

    auc_evaluator = BinaryClassificationEvaluator(labelCol='is_electronic', metricName='areaUnderROC')
    auc = auc_evaluator.evaluate(predictions)
    
    return accuracy, precision, recall, f1, auc

In [69]:
results = []

for name, model in models.items():
    fitted = model.fit(train_df)
    predictions = fitted.transform(test_df)

    acc, prec, rec, f1, auc = confusion_metrics(predictions)

    results.append((name, acc, prec, rec, f1, auc))

evaluation = spark.createDataFrame(results, ['Model', 'Accuracy', 'Precision', 'Recall', 'F1', 'AUC'])

25/10/07 12:26:44 WARN DAGScheduler: Broadcasting large task binary with size 1248.1 KiB
25/10/07 12:26:54 WARN DAGScheduler: Broadcasting large task binary with size 2.2 MiB
25/10/07 12:27:08 WARN DAGScheduler: Broadcasting large task binary with size 4.0 MiB
25/10/07 12:27:26 WARN DAGScheduler: Broadcasting large task binary with size 7.0 MiB
25/10/07 12:27:44 WARN DAGScheduler: Broadcasting large task binary with size 1527.3 KiB
25/10/07 12:27:58 WARN DAGScheduler: Broadcasting large task binary with size 5.0 MiB
25/10/07 12:28:03 WARN DAGScheduler: Broadcasting large task binary with size 5.0 MiB
25/10/07 12:28:11 WARN DAGScheduler: Broadcasting large task binary with size 5.0 MiB
25/10/07 12:28:19 WARN DAGScheduler: Broadcasting large task binary with size 5.0 MiB
25/10/07 12:28:23 WARN DAGScheduler: Broadcasting large task binary with size 5.0 MiB
                                                                                

In [70]:
evaluation.show()

+-------------------+------------------+------------------+-------------------+-------------------+------------------+
|              Model|          Accuracy|         Precision|             Recall|                 F1|               AUC|
+-------------------+------------------+------------------+-------------------+-------------------+------------------+
|Logistic Regression|0.9220527573376446|0.6888782358581016| 0.3533751383253412| 0.4671271840715156|0.8796820465825044|
|      Random Forest|0.9188668703415318|0.7503828483920367|0.24099348333948112|0.36482084690553745|0.8732506095479907|
|                GBT|0.9215059259875655|0.6796618130577736| 0.3558342555022747|  0.467113227342426|0.8864095732925558|
+-------------------+------------------+------------------+-------------------+-------------------+------------------+



In [71]:
stop_spark()

25/10/07 13:17:49 WARN ExecutorPodsWatchSnapshotSource: Kubernetes client has been closed.
