---

# Customer Churn Prediction with SparkML
**EPITA â€“ MSc Artificial Intelligence Systems (AIS)**  
**Spark & Python for Big Data AIS S2 F25**

**Students:** 
- TRUONG Kim Tan
- LE Linh Long
- George
- Farouk

---

## Phase 1: Setup and Data Loading

In [119]:
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler, StandardScaler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

### 1.1. Iniialize SparkSession

In [120]:
spark = SparkSession.builder \
    .appName("ChurnPredictionPipeline") \
    .config("spark.sql.shuffle.partitions", "4") \
    .getOrCreate()

print(spark)

<pyspark.sql.session.SparkSession object at 0x1148dc510>


### 1.2. Load the Dataset

In [121]:
temp_df = pd.read_csv("WA_Fn-UseC_-Telco-Customer-Churn.csv", nrows=1)
cols = temp_df.columns.tolist()

# build schema
schema = StructType([StructField(c, StringType(), True) for c in cols])

df = spark.read.csv("WA_Fn-UseC_-Telco-Customer-Churn.csv", header=True, schema=schema)

### 1.3. Initial Data Inspection

In [122]:
print(f"Total rows: {df.count()}")
df.printSchema()
df.show()

Total rows: 7043
root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: string (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: string (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: string (nullable = true)
 |-- TotalCharges: string (nullable = true)
 |-- Churn: string (nullable = true)

+----------+------+-------------+-------+----------+------+-------

---

## Phase 2: Exploratory Data Analysis (EDA) & Data Cleaning

### 2.1. Data Cleaning (Handling Missing Values)

**a) Column count check for Missing Values**

In [123]:
def count_missing(c, dtype):
    if isinstance(dtype, StringType):
        return F.count(
            F.when(
                F.col(c).isNull() | (F.trim(F.col(c)) == ""), 
                c
            )
        ).alias(c)
    else:
        return F.count(F.when(F.col(c).isNull(), c)).alias(c)

missing_counts = df.select([
    count_missing(c, dtype) for c, dtype in df.dtypes
])

missing_data = missing_counts.collect()[0].asDict()

print(f"{'Column':<20} {'Missing Count':>15}")
print("-" * 40)
for col_name, count in missing_data.items():
    print(f"{col_name:<20} {count:>15}")

Column                 Missing Count
----------------------------------------
customerID                         0
gender                             0
SeniorCitizen                      0
Partner                            0
Dependents                         0
tenure                             0
PhoneService                       0
MultipleLines                      0
InternetService                    0
OnlineSecurity                     0
OnlineBackup                       0
DeviceProtection                   0
TechSupport                        0
StreamingTV                        0
StreamingMovies                    0
Contract                           0
PaperlessBilling                   0
PaymentMethod                      0
MonthlyCharges                     0
TotalCharges                       0
Churn                              0


**b) Data Cleaning**

In [124]:

# Handle TotalCharges: replace empty strings with null, then cast to Double
df = df.withColumn('TotalCharges', 
                   F.when(F.col('TotalCharges').isin("", " "), None)
                   .otherwise(F.col('TotalCharges')))

# Drop rows with null TotalCharges (11 records)
print(f"\nRows before dropping null TotalCharges: {df.count()}")
df_clean = df.na.drop(subset=['TotalCharges'])
print(f"Rows after dropping null TotalCharges: {df_clean.count()}")

# Check for duplicates
print(f"\nDuplicate customerID count: {df_clean.count() - df_clean.select('customerID').distinct().count()}")


Rows before dropping null TotalCharges: 7043
Rows after dropping null TotalCharges: 7032

Duplicate customerID count: 0


### 2.2. Data preparation

In [125]:
df_clean = df_clean.withColumn('TotalCharges', F.col('TotalCharges').cast(DoubleType()))

# Cast SeniorCitizen to Integer (here the values are currently "0", "1")
df_clean = df_clean.withColumn('SeniorCitizen', F.col('SeniorCitizen').cast(IntegerType()))

# Cast tenure to Integer
df_clean = df_clean.withColumn('tenure', F.col('tenure').cast(IntegerType()))

# Cast MonthlyCharges to Double
df_clean = df_clean.withColumn('MonthlyCharges', F.col('MonthlyCharges').cast(DoubleType()))

print("Schema after conversion of columns:")
df_clean.printSchema()

Schema after conversion of columns:
root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: integer (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- TotalCharges: double (nullable = true)
 |-- Churn: string (nullable = true)



### 2.3. Univariate Analysis (Analyzing Single Variables)

**a) Numerical Features**

In [126]:
# numerical columns summary
numerical_cols = ['tenure', 'MonthlyCharges', 'TotalCharges']
df_clean.select(numerical_cols).describe().show()

# SeniorCitizen belongs to categorical column
print("SeniorCitizen Distribution:")
df_clean.groupBy('SeniorCitizen').count().show()

+-------+------------------+------------------+------------------+
|summary|            tenure|    MonthlyCharges|      TotalCharges|
+-------+------------------+------------------+------------------+
|  count|              7032|              7032|              7032|
|   mean|32.421786120591584| 64.79820819112632|2283.3004408418697|
| stddev|24.545259709263245|30.085973884049825| 2266.771361883145|
|    min|                 1|             18.25|              18.8|
|    max|                72|            118.75|            8684.8|
+-------+------------------+------------------+------------------+

SeniorCitizen Distribution:
+-------------+-----+
|SeniorCitizen|count|
+-------------+-----+
|            0| 5890|
|            1| 1142|
+-------------+-----+



**b) Categorical Features**

In [127]:
categorical_cols = [c for c in df_clean.columns if c not in numerical_cols + ['customerID', 'SeniorCitizen']]
print(f"Categorical columns: {categorical_cols}")

# Show value counts for key categorical columns
key_categorical = ['gender', 'Partner', 'Dependents', 'Contract', 
                   'InternetService', 'PaymentMethod', 'Churn']

for col in key_categorical:
    print(f"{col} distribution:")
    df_clean.groupBy(col).count().orderBy(F.desc('count')).show()

Categorical columns: ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod', 'Churn']
gender distribution:
+------+-----+
|gender|count|
+------+-----+
|  Male| 3549|
|Female| 3483|
+------+-----+

Partner distribution:
+-------+-----+
|Partner|count|
+-------+-----+
|     No| 3639|
|    Yes| 3393|
+-------+-----+

Dependents distribution:
+----------+-----+
|Dependents|count|
+----------+-----+
|        No| 4933|
|       Yes| 2099|
+----------+-----+

Contract distribution:
+--------------+-----+
|      Contract|count|
+--------------+-----+
|Month-to-month| 3875|
|      Two year| 1685|
|      One year| 1472|
+--------------+-----+

InternetService distribution:
+---------------+-----+
|InternetService|count|
+---------------+-----+
|    Fiber optic| 3096|
|            DSL| 2416|
|             No| 1520|


### 2.4. Bivariate Analysis (Analyzing Relationships)

In [128]:
print("Churn rate by Contract type:")
df_clean.groupBy('Contract').agg(
    F.count('*').alias('total'),
    F.sum(F.when(F.col('Churn') == 'Yes', 1).otherwise(0)).alias('churned'),
    (F.sum(F.when(F.col('Churn') == 'Yes', 1).otherwise(0)) / F.count('*') * 100).alias('churn_rate')
).show()

print("Churn rate by Internet Service:")
df_clean.groupBy('InternetService').agg(
    F.count('*').alias('total'),
    F.sum(F.when(F.col('Churn') == 'Yes', 1).otherwise(0)).alias('churned'),
    (F.sum(F.when(F.col('Churn') == 'Yes', 1).otherwise(0)) / F.count('*') * 100).alias('churn_rate')
).show()

# Tenure statistics by Churn
print("Tenure statistics by Churn:")
df_clean.groupBy('Churn').agg(
    F.avg('tenure').alias('avg_tenure'),
    F.min('tenure').alias('min_tenure'),
    F.max('tenure').alias('max_tenure'),
    F.avg('MonthlyCharges').alias('avg_monthly_charges'),
    F.avg('TotalCharges').alias('avg_total_charges')
).show()

Churn rate by Contract type:
+--------------+-----+-------+------------------+
|      Contract|total|churned|        churn_rate|
+--------------+-----+-------+------------------+
|Month-to-month| 3875|   1655| 42.70967741935484|
|      One year| 1472|    166|11.277173913043478|
|      Two year| 1685|     48|2.8486646884272995|
+--------------+-----+-------+------------------+

Churn rate by Internet Service:
+---------------+-----+-------+------------------+
|InternetService|total|churned|        churn_rate|
+---------------+-----+-------+------------------+
|             No| 1520|    113| 7.434210526315789|
|            DSL| 2416|    459|18.998344370860927|
|    Fiber optic| 3096|   1297| 41.89276485788114|
+---------------+-----+-------+------------------+

Tenure statistics by Churn:
+-----+------------------+----------+----------+-------------------+------------------+
|Churn|        avg_tenure|min_tenure|max_tenure|avg_monthly_charges| avg_total_charges|
+-----+------------------+

---

## Phase 3: Data Transformation & Feature Engineering

### 3.1. Identify Feature Columns

In [129]:
# we identify feature columns (except for customerID and target)
feature_cols = [c for c in df_clean.columns if c not in ['customerID', 'Churn']]

# we separate columns into categorical and numerical
cat_cols = [c for c in feature_cols if c not in numerical_cols]
num_cols = [c for c in feature_cols if c in numerical_cols]

print(f"Features to use: {feature_cols}")
print(f"Categorical: {cat_cols}")
print(f"Numerical: {num_cols}")

# we convert the target variable to numeric
df_clean = df_clean.withColumn('label', F.when(F.col('Churn') == 'Yes', 1).otherwise(0))

Features to use: ['gender', 'SeniorCitizen', 'Partner', 'Dependents', 'tenure', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod', 'MonthlyCharges', 'TotalCharges']
Categorical: ['gender', 'SeniorCitizen', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']
Numerical: ['tenure', 'MonthlyCharges', 'TotalCharges']


### 3.2. Define Pipeline Stages

**a) Categorical Encoding**

In [130]:
indexers = []
encoders = []
indexed_cols = []
encoded_cols = []

for col in cat_cols:
    indexer = StringIndexer(inputCol=col, outputCol=f"{col}_indexed", handleInvalid='keep')
    encoder = OneHotEncoder(inputCol=f"{col}_indexed", outputCol=f"{col}_encoded", dropLast=False)
    indexers.append(indexer)
    encoders.append(encoder)
    indexed_cols.append(f"{col}_indexed")
    encoded_cols.append(f"{col}_encoded")

**b) Vector Assembly**

In [131]:
# combine numerical features with encoded categorical features
assembler_inputs = num_cols + encoded_cols
vector_assembler = VectorAssembler(inputCols=assembler_inputs, outputCol='features_unscaled')

**c) Feature Scaling**

In [132]:
scaler = StandardScaler(inputCol='features_unscaled', outputCol='features', 
                        withStd=True, withMean=True)

**d) Create Pipeline**

In [133]:
pipeline_stages = indexers + encoders + [vector_assembler, scaler]
pipeline = Pipeline(stages=pipeline_stages)

print(f"Pipeline stages: {len(indexers)} indexers, {len(encoders)} encoders, 1 assembler, 1 scaler")
print(f"Total stages in pipeline: {len(pipeline_stages)}")

Pipeline stages: 16 indexers, 16 encoders, 1 assembler, 1 scaler
Total stages in pipeline: 34


---

## Phase 4: Build ML Pipeline, Train & Evaluate

### 4.1. Define the Model

In [134]:
lr = LogisticRegression(featuresCol='features', labelCol='label', 
                        maxIter=10, regParam=0.01, elasticNetParam=0.0)

### 4.2. Assemble the Pipeline

In [135]:
pipeline = Pipeline(stages=indexers + encoders + [vector_assembler, scaler, lr])

print("Pipeline created successfully with the following stages:")
for i, stage in enumerate(pipeline.getStages()):
    stage_name = stage.__class__.__name__
    print(f"  {i+1}. {stage_name}")

Pipeline created successfully with the following stages:
  1. StringIndexer
  2. StringIndexer
  3. StringIndexer
  4. StringIndexer
  5. StringIndexer
  6. StringIndexer
  7. StringIndexer
  8. StringIndexer
  9. StringIndexer
  10. StringIndexer
  11. StringIndexer
  12. StringIndexer
  13. StringIndexer
  14. StringIndexer
  15. StringIndexer
  16. StringIndexer
  17. OneHotEncoder
  18. OneHotEncoder
  19. OneHotEncoder
  20. OneHotEncoder
  21. OneHotEncoder
  22. OneHotEncoder
  23. OneHotEncoder
  24. OneHotEncoder
  25. OneHotEncoder
  26. OneHotEncoder
  27. OneHotEncoder
  28. OneHotEncoder
  29. OneHotEncoder
  30. OneHotEncoder
  31. OneHotEncoder
  32. OneHotEncoder
  33. VectorAssembler
  34. StandardScaler
  35. LogisticRegression


In [136]:
# split the data for 5.1
train_df, test_df = df_clean.randomSplit([0.8, 0.2], seed=42)
print(f"Training set: {train_df.count()} records")
print(f"Test set: {test_df.count()} records")

# see class distribution
print("Class distribution in training set:")
train_df.groupBy('label').count().show()

Training set: 5690 records
Test set: 1342 records
Class distribution in training set:
+-----+-----+
|label|count|
+-----+-----+
|    0| 4175|
|    1| 1515|
+-----+-----+



In [137]:
# train
model = pipeline.fit(train_df)

In [138]:
# test
predictions = model.transform(test_df)

print("Sample predictions:")
predictions.select('customerID', 'Churn', 'label', 'prediction', 'probability').show(10, truncate=False)


--- 4.5 Making Predictions ---
Sample predictions:
+----------+-----+-----+----------+-----------------------------------------+
|customerID|Churn|label|prediction|probability                              |
+----------+-----+-----+----------+-----------------------------------------+
|0004-TLHLJ|Yes  |1    |1.0       |[0.3542472128970759,0.645752787102924]   |
|0013-SMEOE|No   |0    |0.0       |[0.960249300395742,0.03975069960425803]  |
|0015-UOCOJ|No   |0    |0.0       |[0.5980236992924435,0.4019763007075565]  |
|0019-EFAEP|No   |0    |0.0       |[0.9546219641705298,0.0453780358294702]  |
|0023-HGHWL|Yes  |1    |1.0       |[0.389878594351926,0.6101214056480739]   |
|0030-FNXPP|No   |0    |0.0       |[0.8412278128259559,0.15877218717404407] |
|0042-RLHYP|No   |0    |0.0       |[0.9921528552325698,0.007847144767430203]|
|0057-QBUQH|No   |0    |0.0       |[0.9650239865928006,0.03497601340719936] |
|0078-XZMHT|No   |0    |0.0       |[0.9731822276916164,0.02681777230838356] |
|0080-EMYVY|

In [139]:
# evaluate

# Accuracy
evaluator_acc = MulticlassClassificationEvaluator(labelCol='label', predictionCol='prediction', 
                                                  metricName='accuracy')
accuracy = evaluator_acc.evaluate(predictions)
print(f"Accuracy: {accuracy:.4f}")

# F1 Score
evaluator_f1 = MulticlassClassificationEvaluator(labelCol='label', predictionCol='prediction', 
                                                 metricName='f1')
f1_score = evaluator_f1.evaluate(predictions)
print(f"F1 Score: {f1_score:.4f}")

# Precision
evaluator_precision = MulticlassClassificationEvaluator(labelCol='label', predictionCol='prediction', 
                                                       metricName='weightedPrecision')
precision = evaluator_precision.evaluate(predictions)
print(f"Precision: {precision:.4f}")

# Recall
evaluator_recall = MulticlassClassificationEvaluator(labelCol='label', predictionCol='prediction', 
                                                    metricName='weightedRecall')
recall = evaluator_recall.evaluate(predictions)
print(f"Recall: {recall:.4f}")

# AUC-ROC
evaluator_auc = BinaryClassificationEvaluator(labelCol='label', rawPredictionCol='rawPrediction', 
                                             metricName='areaUnderROC')
auc = evaluator_auc.evaluate(predictions)
print(f"AUC-ROC: {auc:.4f}")

# Confusion Matrix
print("Confusion Matrix:")
predictions.groupBy('label', 'prediction').count().orderBy('label', 'prediction').show()

Accuracy: 0.8122
F1 Score: 0.8044
Precision: 0.8030
Recall: 0.8122
AUC-ROC: 0.8535
Confusion Matrix:
+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|    0|       0.0|  900|
|    0|       1.0|   88|
|    1|       0.0|  164|
|    1|       1.0|  190|
+-----+----------+-----+



In [140]:
# Stop Spark
spark.stop()