---

# Customer Churn Prediction with SparkML
**EPITA – MSc Artificial Intelligence Systems (AIS)**  
**Spark & Python for Big Data AIS S2 F25**

**Students:** 
- TRUONG Kim Tan
- LE Linh Long
- George
- Farouk

---

## Phase 1: Setup and Data Loading

In [68]:
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, DoubleType, IntegerType, StringType
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler, StandardScaler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

### 1.1. Iniialize SparkSession

In [69]:
spark = SparkSession.builder \
    .appName("ChurnPredictionPipeline") \
    .config("spark.sql.shuffle.partitions", "4") \
    .config("spark.driver.memory", "4g") \
    .config("spark.driver.maxResultSize", "2g") \
    .config("spark.executor.memory", "4g") \
    .getOrCreate()

print(spark)

<pyspark.sql.session.SparkSession object at 0x11534dfd0>


### 1.2. Load the Dataset

In [70]:
temp_df = pd.read_csv("WA_Fn-UseC_-Telco-Customer-Churn.csv", nrows=1)
cols = temp_df.columns.tolist()

# build schema
schema = StructType([StructField(c, StringType(), True) for c in cols])

df = spark.read.csv("WA_Fn-UseC_-Telco-Customer-Churn.csv", header=True, schema=schema)

### 1.3. Initial Data Inspection

In [71]:
print(f"Total rows: {df.count()}")
df.printSchema()
df.show()

Total rows: 7043
root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: string (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: string (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: string (nullable = true)
 |-- TotalCharges: string (nullable = true)
 |-- Churn: string (nullable = true)

+----------+------+-------------+-------+----------+------+-------

---

## Phase 2: Exploratory Data Analysis (EDA) & Data Cleaning

### 2.1. Data Cleaning (Handling Missing Values)

**a) Column count check for Missing Values**

In [72]:
def count_missing(c, dtype):
    if isinstance(dtype, StringType):
        return F.count(
            F.when(
                F.col(c).isNull() | (F.trim(F.col(c)) == ""), 
                c
            )
        ).alias(c)
    else:
        return F.count(F.when(F.col(c).isNull(), c)).alias(c)

missing_counts = df.select([
    count_missing(c, dtype) for c, dtype in df.dtypes
])

missing_data = missing_counts.collect()[0].asDict()

print(f"{'Column':<20} {'Missing Count':>15}")
print("-" * 40)
for col_name, count in missing_data.items():
    print(f"{col_name:<20} {count:>15}")

Column                 Missing Count
----------------------------------------
customerID                         0
gender                             0
SeniorCitizen                      0
Partner                            0
Dependents                         0
tenure                             0
PhoneService                       0
MultipleLines                      0
InternetService                    0
OnlineSecurity                     0
OnlineBackup                       0
DeviceProtection                   0
TechSupport                        0
StreamingTV                        0
StreamingMovies                    0
Contract                           0
PaperlessBilling                   0
PaymentMethod                      0
MonthlyCharges                     0
TotalCharges                       0
Churn                              0


**b) Data Cleaning**

In [73]:
# Handle TotalCharges: replace empty strings with null, then cast to Double
df = df.withColumn('TotalCharges', 
                   F.when(F.col('TotalCharges').isin("", " "), None)
                   .otherwise(F.col('TotalCharges')))

# Drop rows with null TotalCharges (11 records)
print(f"\nRows before dropping null TotalCharges: {df.count()}")
df_clean = df.na.drop(subset=['TotalCharges'])
print(f"Rows after dropping null TotalCharges: {df_clean.count()}")

# Check for duplicates
print(f"\nDuplicate customerID count: {df_clean.count() - df_clean.select('customerID').distinct().count()}")


Rows before dropping null TotalCharges: 7043
Rows after dropping null TotalCharges: 7032

Duplicate customerID count: 0


### 2.2. Data preparation

In [74]:
df_clean = df_clean.withColumn('TotalCharges', F.col('TotalCharges').cast(DoubleType()))

# Cast SeniorCitizen to Integer (here the values are currently "0", "1")
df_clean = df_clean.withColumn('SeniorCitizen', F.col('SeniorCitizen').cast(IntegerType()))

# Cast tenure to Integer
df_clean = df_clean.withColumn('tenure', F.col('tenure').cast(IntegerType()))

# Cast MonthlyCharges to Double
df_clean = df_clean.withColumn('MonthlyCharges', F.col('MonthlyCharges').cast(DoubleType()))

print("Schema after conversion of columns:")
df_clean.printSchema()

Schema after conversion of columns:
root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: integer (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- TotalCharges: double (nullable = true)
 |-- Churn: string (nullable = true)



### 2.3. Univariate Analysis (Analyzing Single Variables)

**a) Numerical Features**

In [75]:
# numerical columns summary
numerical_cols = ['tenure', 'MonthlyCharges', 'TotalCharges']
df_clean.select(numerical_cols).describe().show()

# SeniorCitizen belongs to categorical column
print("SeniorCitizen Distribution:")
df_clean.groupBy('SeniorCitizen').count().show()

+-------+------------------+------------------+------------------+
|summary|            tenure|    MonthlyCharges|      TotalCharges|
+-------+------------------+------------------+------------------+
|  count|              7032|              7032|              7032|
|   mean|32.421786120591584| 64.79820819112632|2283.3004408418697|
| stddev|24.545259709263245|30.085973884049825| 2266.771361883145|
|    min|                 1|             18.25|              18.8|
|    max|                72|            118.75|            8684.8|
+-------+------------------+------------------+------------------+

SeniorCitizen Distribution:
+-------------+-----+
|SeniorCitizen|count|
+-------------+-----+
|            0| 5890|
|            1| 1142|
+-------------+-----+



**b) Categorical Features**

In [76]:
categorical_cols = [c for c in df_clean.columns if c not in numerical_cols + ['customerID', 'SeniorCitizen']]
print(f"Categorical columns: {categorical_cols}")

# Show value counts for key categorical columns
key_categorical = ['gender', 'Partner', 'Dependents', 'Contract', 
                   'InternetService', 'PaymentMethod', 'Churn']

for col in key_categorical:
    print(f"{col} distribution:")
    df_clean.groupBy(col).count().orderBy(F.desc('count')).show()

Categorical columns: ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod', 'Churn']
gender distribution:
+------+-----+
|gender|count|
+------+-----+
|  Male| 3549|
|Female| 3483|
+------+-----+

Partner distribution:
+-------+-----+
|Partner|count|
+-------+-----+
|     No| 3639|
|    Yes| 3393|
+-------+-----+

Dependents distribution:
+----------+-----+
|Dependents|count|
+----------+-----+
|        No| 4933|
|       Yes| 2099|
+----------+-----+

Contract distribution:
+--------------+-----+
|      Contract|count|
+--------------+-----+
|Month-to-month| 3875|
|      Two year| 1685|
|      One year| 1472|
+--------------+-----+

InternetService distribution:
+---------------+-----+
|InternetService|count|
+---------------+-----+
|    Fiber optic| 3096|
|            DSL| 2416|
|             No| 1520|


### 2.4. Bivariate Analysis (Analyzing Relationships)

In [77]:
print("Churn rate by Contract type:")
df_clean.groupBy('Contract').agg(
    F.count('*').alias('total'),
    F.sum(F.when(F.col('Churn') == 'Yes', 1).otherwise(0)).alias('churned'),
    (F.sum(F.when(F.col('Churn') == 'Yes', 1).otherwise(0)) / F.count('*') * 100).alias('churn_rate')
).show()

print("Churn rate by Internet Service:")
df_clean.groupBy('InternetService').agg(
    F.count('*').alias('total'),
    F.sum(F.when(F.col('Churn') == 'Yes', 1).otherwise(0)).alias('churned'),
    (F.sum(F.when(F.col('Churn') == 'Yes', 1).otherwise(0)) / F.count('*') * 100).alias('churn_rate')
).show()

# Tenure statistics by Churn
print("Tenure statistics by Churn:")
df_clean.groupBy('Churn').agg(
    F.avg('tenure').alias('avg_tenure'),
    F.min('tenure').alias('min_tenure'),
    F.max('tenure').alias('max_tenure'),
    F.avg('MonthlyCharges').alias('avg_monthly_charges'),
    F.avg('TotalCharges').alias('avg_total_charges')
).show()

Churn rate by Contract type:
+--------------+-----+-------+------------------+
|      Contract|total|churned|        churn_rate|
+--------------+-----+-------+------------------+
|Month-to-month| 3875|   1655| 42.70967741935484|
|      One year| 1472|    166|11.277173913043478|
|      Two year| 1685|     48|2.8486646884272995|
+--------------+-----+-------+------------------+

Churn rate by Internet Service:
+---------------+-----+-------+------------------+
|InternetService|total|churned|        churn_rate|
+---------------+-----+-------+------------------+
|             No| 1520|    113| 7.434210526315789|
|            DSL| 2416|    459|18.998344370860927|
|    Fiber optic| 3096|   1297| 41.89276485788114|
+---------------+-----+-------+------------------+

Tenure statistics by Churn:
+-----+------------------+----------+----------+-------------------+------------------+
|Churn|        avg_tenure|min_tenure|max_tenure|avg_monthly_charges| avg_total_charges|
+-----+------------------+

---

## Phase 3: Data Transformation & Feature Engineering

### 3.1. Identify Feature Columns

In [78]:
# we identify feature columns (except for customerID and target)
feature_cols = [c for c in df_clean.columns if c not in ['customerID', 'Churn']]

# we separate columns into categorical and numerical
cat_cols = [c for c in feature_cols if c not in numerical_cols]
num_cols = [c for c in feature_cols if c in numerical_cols]

print(f"Features to use: {feature_cols}")
print(f"Categorical: {cat_cols}")
print(f"Numerical: {num_cols}")

# we convert the target variable to numeric
df_clean = df_clean.withColumn('label', F.when(F.col('Churn') == 'Yes', 1).otherwise(0))

Features to use: ['gender', 'SeniorCitizen', 'Partner', 'Dependents', 'tenure', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod', 'MonthlyCharges', 'TotalCharges']
Categorical: ['gender', 'SeniorCitizen', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']
Numerical: ['tenure', 'MonthlyCharges', 'TotalCharges']


### 3.2. Define Pipeline Stages

**a) Categorical Encoding**

In [79]:
indexers = []
encoders = []
indexed_cols = []
encoded_cols = []

for col in cat_cols:
    indexer = StringIndexer(inputCol=col, outputCol=f"{col}_indexed", handleInvalid='keep')
    encoder = OneHotEncoder(inputCol=f"{col}_indexed", outputCol=f"{col}_encoded", dropLast=False)
    indexers.append(indexer)
    encoders.append(encoder)
    indexed_cols.append(f"{col}_indexed")
    encoded_cols.append(f"{col}_encoded")

**b) Vector Assembly**

In [80]:
# combine numerical features with encoded categorical features
assembler_inputs = num_cols + encoded_cols
vector_assembler = VectorAssembler(inputCols=assembler_inputs, outputCol='features_unscaled')

**c) Feature Scaling**

In [81]:
scaler = StandardScaler(inputCol='features_unscaled', outputCol='features', 
                        withStd=True, withMean=True)

**d) Create Pipeline**

In [82]:
pipeline_stages = indexers + encoders + [vector_assembler, scaler]
pipeline = Pipeline(stages=pipeline_stages)

print(f"Pipeline stages: {len(indexers)} indexers, {len(encoders)} encoders, 1 assembler, 1 scaler")
print(f"Total stages in pipeline: {len(pipeline_stages)}")

Pipeline stages: 16 indexers, 16 encoders, 1 assembler, 1 scaler
Total stages in pipeline: 34


---

## Phase 4: Linear Regression Classification

### 4.1. Define Logistic Regression & Imports

In [83]:
lr = LogisticRegression(
    featuresCol='features',
    labelCol='label'
)


### 4.2. Build Full Pipeline (Preprocessing + LR)

In [84]:
lr_pipeline = Pipeline(stages=indexers + encoders + [vector_assembler, scaler, lr])

print('LR pipeline created with the following stages:')
for i, stage in enumerate(lr_pipeline.getStages()):
    print(f'  {i+1}. {stage.__class__.__name__}')


LR pipeline created with the following stages:
  1. StringIndexer
  2. StringIndexer
  3. StringIndexer
  4. StringIndexer
  5. StringIndexer
  6. StringIndexer
  7. StringIndexer
  8. StringIndexer
  9. StringIndexer
  10. StringIndexer
  11. StringIndexer
  12. StringIndexer
  13. StringIndexer
  14. StringIndexer
  15. StringIndexer
  16. StringIndexer
  17. OneHotEncoder
  18. OneHotEncoder
  19. OneHotEncoder
  20. OneHotEncoder
  21. OneHotEncoder
  22. OneHotEncoder
  23. OneHotEncoder
  24. OneHotEncoder
  25. OneHotEncoder
  26. OneHotEncoder
  27. OneHotEncoder
  28. OneHotEncoder
  29. OneHotEncoder
  30. OneHotEncoder
  31. OneHotEncoder
  32. OneHotEncoder
  33. VectorAssembler
  34. StandardScaler
  35. LogisticRegression


### 4.3. Train / Test Split

In [85]:
train_df, test_df = df_clean.randomSplit([0.8, 0.2], seed=42)
print(f'Training set : {train_df.count()} records')
print(f'Test set     : {test_df.count()} records')
print('\nClass distribution in training set:')
train_df.groupBy('label').count().orderBy('label').show()


Training set : 5690 records
Test set     : 1342 records

Class distribution in training set:
+-----+-----+
|label|count|
+-----+-----+
|    0| 4175|
|    1| 1515|
+-----+-----+



### 4.4. Define Hyperparameter Grid with ParamGridBuilder

In [86]:
lr_param_grid = (
    ParamGridBuilder()
    .addGrid(lr.regParam,        [0.01, 0.1, 0.5])
    .addGrid(lr.elasticNetParam, [0.0, 0.5])
    .addGrid(lr.maxIter,         [10, 50])
    .build()
)

print(f'LR ParamGrid built — {len(lr_param_grid)} combinations')
for i, params in enumerate(lr_param_grid):
    combo = {p.name: v for p, v in params.items()}
    print(f'  Combo {i+1:>2}: {combo}')


LR ParamGrid built — 12 combinations
  Combo  1: {'regParam': 0.01, 'elasticNetParam': 0.0, 'maxIter': 10}
  Combo  2: {'regParam': 0.01, 'elasticNetParam': 0.0, 'maxIter': 50}
  Combo  3: {'regParam': 0.01, 'elasticNetParam': 0.5, 'maxIter': 10}
  Combo  4: {'regParam': 0.01, 'elasticNetParam': 0.5, 'maxIter': 50}
  Combo  5: {'regParam': 0.1, 'elasticNetParam': 0.0, 'maxIter': 10}
  Combo  6: {'regParam': 0.1, 'elasticNetParam': 0.0, 'maxIter': 50}
  Combo  7: {'regParam': 0.1, 'elasticNetParam': 0.5, 'maxIter': 10}
  Combo  8: {'regParam': 0.1, 'elasticNetParam': 0.5, 'maxIter': 50}
  Combo  9: {'regParam': 0.5, 'elasticNetParam': 0.0, 'maxIter': 10}
  Combo 10: {'regParam': 0.5, 'elasticNetParam': 0.0, 'maxIter': 50}
  Combo 11: {'regParam': 0.5, 'elasticNetParam': 0.5, 'maxIter': 10}
  Combo 12: {'regParam': 0.5, 'elasticNetParam': 0.5, 'maxIter': 50}


### 4.5. Cross-Validation Setup (5-fold)

In [87]:
cv_evaluator = BinaryClassificationEvaluator(
    labelCol='label',
    rawPredictionCol='rawPrediction',
    metricName='areaUnderROC'
)

lr_cross_val = CrossValidator(
    estimator=lr_pipeline,
    estimatorParamMaps=lr_param_grid,
    evaluator=cv_evaluator,
    numFolds=5,
    seed=42
)


### 4.6. Train with Cross-Validation

In [88]:
lr_cv_model = lr_cross_val.fit(train_df)

print('\nAverage AUC-ROC per hyperparameter combination:')
print(f"{'Combo':<8} {'regParam':<10} {'elasticNet':<12} {'maxIter':<9} {'AUC-ROC':>8}")
print('-' * 53)
for i, (params, score) in enumerate(zip(lr_param_grid, lr_cv_model.avgMetrics)):
    combo = {p.name: v for p, v in params.items()}
    print(f"{i+1:<8} {combo['regParam']:<10} {combo['elasticNetParam']:<12} {combo['maxIter']:<9} {score:>8.4f}")

best_lr_score  = max(lr_cv_model.avgMetrics)
best_lr_idx    = lr_cv_model.avgMetrics.index(best_lr_score)
best_lr_params = {p.name: v for p, v in lr_param_grid[best_lr_idx].items()}
print(f'\nBest combo #{best_lr_idx+1}: {best_lr_params}  →  AUC = {best_lr_score:.4f}')



Average AUC-ROC per hyperparameter combination:
Combo    regParam   elasticNet   maxIter    AUC-ROC
-----------------------------------------------------
1        0.01       0.0          10          0.8389
2        0.01       0.0          50          0.8397
3        0.01       0.5          10          0.8401
4        0.01       0.5          50          0.8404
5        0.1        0.0          10          0.8367
6        0.1        0.0          50          0.8367
7        0.1        0.5          10          0.8334
8        0.1        0.5          50          0.8332
9        0.5        0.0          10          0.8301
10       0.5        0.0          50          0.8301
11       0.5        0.5          10          0.5000
12       0.5        0.5          50          0.5000

Best combo #4: {'regParam': 0.01, 'elasticNetParam': 0.5, 'maxIter': 50}  →  AUC = 0.8404


### 4.7. Evaluate Best Model on Test Set

In [89]:
lr_predictions = lr_cv_model.transform(test_df)

lr_acc = MulticlassClassificationEvaluator(labelCol='label', predictionCol='prediction', metricName='accuracy').evaluate(lr_predictions)
lr_f1 = MulticlassClassificationEvaluator(labelCol='label', predictionCol='prediction', metricName='f1').evaluate(lr_predictions)
lr_prec = MulticlassClassificationEvaluator(labelCol='label', predictionCol='prediction', metricName='weightedPrecision').evaluate(lr_predictions)
lr_rec = MulticlassClassificationEvaluator(labelCol='label', predictionCol='prediction', metricName='weightedRecall').evaluate(lr_predictions)
lr_auc = BinaryClassificationEvaluator(labelCol='label', rawPredictionCol='rawPrediction', metricName='areaUnderROC').evaluate(lr_predictions)

print("\nLogistic Regression Evaluation")
print("-" * 40)
print(f"  Accuracy          : {lr_acc:.4f}")
print(f"  F1 Score          : {lr_f1:.4f}")
print(f"  Weighted Precision: {lr_prec:.4f}")
print(f"  Weighted Recall   : {lr_rec:.4f}")
print(f"  AUC-ROC           : {lr_auc:.4f}")

print("\nConfusion Matrix:")
lr_predictions.groupBy('label', 'prediction').count().orderBy('label', 'prediction').show()


Logistic Regression Evaluation
----------------------------------------
  Accuracy          : 0.8137
  F1 Score          : 0.8050
  Weighted Precision: 0.8042
  Weighted Recall   : 0.8137
  AUC-ROC           : 0.8555

Confusion Matrix:
+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|    0|       0.0|  905|
|    0|       1.0|   83|
|    1|       0.0|  167|
|    1|       1.0|  187|
+-----+----------+-----+



---

## Phase 5: Random Forest Classification

### 5.1. Add RandomForestClassifier to the Pipeline

In [90]:
rf = RandomForestClassifier(
    featuresCol='features',
    labelCol='label',
    seed=42
)

rf_pipeline = Pipeline(stages=indexers + encoders + [vector_assembler, scaler, rf])

print("RF pipeline created with the following stages:")
for i, stage in enumerate(rf_pipeline.getStages()):
    print(f"  {i+1}. {stage.__class__.__name__}")

RF pipeline created with the following stages:
  1. StringIndexer
  2. StringIndexer
  3. StringIndexer
  4. StringIndexer
  5. StringIndexer
  6. StringIndexer
  7. StringIndexer
  8. StringIndexer
  9. StringIndexer
  10. StringIndexer
  11. StringIndexer
  12. StringIndexer
  13. StringIndexer
  14. StringIndexer
  15. StringIndexer
  16. StringIndexer
  17. OneHotEncoder
  18. OneHotEncoder
  19. OneHotEncoder
  20. OneHotEncoder
  21. OneHotEncoder
  22. OneHotEncoder
  23. OneHotEncoder
  24. OneHotEncoder
  25. OneHotEncoder
  26. OneHotEncoder
  27. OneHotEncoder
  28. OneHotEncoder
  29. OneHotEncoder
  30. OneHotEncoder
  31. OneHotEncoder
  32. OneHotEncoder
  33. VectorAssembler
  34. StandardScaler
  35. RandomForestClassifier


### 5.2. Define Hyperparameter Grid with ParamGridBuilder

In [91]:
rf_param_grid = (
    ParamGridBuilder()
    .addGrid(rf.numTrees,             [50, 70])
    .addGrid(rf.maxDepth,             [4, 6, 8])
    .addGrid(rf.minInstancesPerNode,  [1, 2])
    .build()
)

print(f"ParamGrid built — {len(rf_param_grid)} hyperparameter combinations to evaluate")
for i, params in enumerate(rf_param_grid):
    combo = {p.name: v for p, v in params.items()}
    print(f"  Combo {i+1:>2}: {combo}")


ParamGrid built — 12 hyperparameter combinations to evaluate
  Combo  1: {'numTrees': 50, 'maxDepth': 4, 'minInstancesPerNode': 1}
  Combo  2: {'numTrees': 50, 'maxDepth': 4, 'minInstancesPerNode': 2}
  Combo  3: {'numTrees': 50, 'maxDepth': 6, 'minInstancesPerNode': 1}
  Combo  4: {'numTrees': 50, 'maxDepth': 6, 'minInstancesPerNode': 2}
  Combo  5: {'numTrees': 50, 'maxDepth': 8, 'minInstancesPerNode': 1}
  Combo  6: {'numTrees': 50, 'maxDepth': 8, 'minInstancesPerNode': 2}
  Combo  7: {'numTrees': 70, 'maxDepth': 4, 'minInstancesPerNode': 1}
  Combo  8: {'numTrees': 70, 'maxDepth': 4, 'minInstancesPerNode': 2}
  Combo  9: {'numTrees': 70, 'maxDepth': 6, 'minInstancesPerNode': 1}
  Combo 10: {'numTrees': 70, 'maxDepth': 6, 'minInstancesPerNode': 2}
  Combo 11: {'numTrees': 70, 'maxDepth': 8, 'minInstancesPerNode': 1}
  Combo 12: {'numTrees': 70, 'maxDepth': 8, 'minInstancesPerNode': 2}


### 5.3. Cross-Validation Setup (5-fold)

In [92]:
rf_cross_val = CrossValidator(
    estimator=rf_pipeline,
    estimatorParamMaps=rf_param_grid,
    evaluator=cv_evaluator,
    numFolds=5,
    seed=42,
    parallelism=1
)

### 5.4. Train with Cross-Validation

In [93]:
rf_cv_model = rf_cross_val.fit(train_df)

print("\nAverage AUC-ROC per hyperparameter combination:")
print(f"{'Combo':<8} {'numTrees':<10} {'maxDepth':<10} {'minInst':<8} {'AUC-ROC':>8}")
print("-" * 50)
for i, (params, score) in enumerate(zip(rf_param_grid, rf_cv_model.avgMetrics)):
    combo = {p.name: v for p, v in params.items()}
    print(f"{i+1:<8} {combo['numTrees']:<10} {combo['maxDepth']:<10} {combo['minInstancesPerNode']:<8} {score:>8.4f}")

best_rf_score = max(rf_cv_model.avgMetrics)
best_rf_idx   = rf_cv_model.avgMetrics.index(best_rf_score)
best_rf_params = {p.name: v for p, v in rf_param_grid[best_rf_idx].items()}
print(f"\nBest combo #{best_rf_idx+1}: {best_rf_params}  →  AUC = {best_rf_score:.4f}")

26/02/22 21:29:15 WARN DAGScheduler: Broadcasting large task binary with size 1028.3 KiB
26/02/22 21:29:15 WARN DAGScheduler: Broadcasting large task binary with size 1572.8 KiB
26/02/22 21:29:17 WARN DAGScheduler: Broadcasting large task binary with size 1028.0 KiB
26/02/22 21:29:17 WARN DAGScheduler: Broadcasting large task binary with size 1562.9 KiB
26/02/22 21:29:25 WARN DAGScheduler: Broadcasting large task binary with size 1341.4 KiB
26/02/22 21:29:26 WARN DAGScheduler: Broadcasting large task binary with size 2.1 MiB
26/02/22 21:29:26 WARN DAGScheduler: Broadcasting large task binary with size 1281.4 KiB
26/02/22 21:29:28 WARN DAGScheduler: Broadcasting large task binary with size 1330.5 KiB
26/02/22 21:29:28 WARN DAGScheduler: Broadcasting large task binary with size 2.0 MiB
26/02/22 21:29:28 WARN DAGScheduler: Broadcasting large task binary with size 1175.7 KiB
26/02/22 21:29:36 WARN DAGScheduler: Broadcasting large task binary with size 1031.1 KiB
26/02/22 21:29:36 WARN DAGS


Average AUC-ROC per hyperparameter combination:
Combo    numTrees   maxDepth   minInst   AUC-ROC
--------------------------------------------------
1        50         4          1          0.8308
2        50         4          2          0.8311
3        50         6          1          0.8386
4        50         6          2          0.8383
5        50         8          1          0.8401
6        50         8          2          0.8404
7        70         4          1          0.8310
8        70         4          2          0.8308
9        70         6          1          0.8380
10       70         6          2          0.8380
11       70         8          1          0.8406
12       70         8          2          0.8402

Best combo #11: {'numTrees': 70, 'maxDepth': 8, 'minInstancesPerNode': 1}  →  AUC = 0.8406


### 5.5. Evaluate Best Model on Test Set

In [94]:
rf_predictions = rf_cv_model.transform(test_df)

rf_acc = MulticlassClassificationEvaluator(labelCol='label', predictionCol='prediction', metricName='accuracy').evaluate(rf_predictions)
rf_f1 = MulticlassClassificationEvaluator(labelCol='label', predictionCol='prediction', metricName='f1').evaluate(rf_predictions)
rf_prec = MulticlassClassificationEvaluator(labelCol='label', predictionCol='prediction', metricName='weightedPrecision').evaluate(rf_predictions)
rf_rec = MulticlassClassificationEvaluator(labelCol='label', predictionCol='prediction', metricName='weightedRecall').evaluate(rf_predictions)
rf_auc = BinaryClassificationEvaluator(labelCol='label', rawPredictionCol='rawPrediction', metricName='areaUnderROC').evaluate(rf_predictions)

print("\nRandom Forest Evaluation")
print("-" * 40)
print(f"  Accuracy          : {rf_acc:.4f}")
print(f"  F1 Score          : {rf_f1:.4f}")
print(f"  Weighted Precision: {rf_prec:.4f}")
print(f"  Weighted Recall   : {rf_rec:.4f}")
print(f"  AUC-ROC           : {rf_auc:.4f}")

print("\nConfusion Matrix:")
rf_predictions.groupBy('label', 'prediction').count().orderBy('label', 'prediction').show()


26/02/22 21:30:58 WARN DAGScheduler: Broadcasting large task binary with size 1272.7 KiB
26/02/22 21:30:58 WARN DAGScheduler: Broadcasting large task binary with size 1272.7 KiB
26/02/22 21:30:58 WARN DAGScheduler: Broadcasting large task binary with size 1272.7 KiB
26/02/22 21:30:58 WARN DAGScheduler: Broadcasting large task binary with size 1272.7 KiB
26/02/22 21:30:58 WARN DAGScheduler: Broadcasting large task binary with size 1261.0 KiB
26/02/22 21:30:58 WARN DAGScheduler: Broadcasting large task binary with size 1268.7 KiB



Random Forest Evaluation
----------------------------------------
  Accuracy          : 0.8040
  F1 Score          : 0.7916
  Weighted Precision: 0.7923
  Weighted Recall   : 0.8040
  AUC-ROC           : 0.8557

Confusion Matrix:
+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|    0|       0.0|  910|
|    0|       1.0|   78|
|    1|       0.0|  185|
|    1|       1.0|  169|
+-----+----------+-----+



26/02/22 21:30:58 WARN DAGScheduler: Broadcasting large task binary with size 1226.7 KiB
