In [0]:
# Load CSV into a Spark DataFrame
df = spark.read.csv("/FileStore/tables/WA_Fn_UseC__Telco_Customer_Churn.csv", header=True, inferSchema=True)

# Display first few rows
df.display()


customerID,gender,SeniorCitizen,Partner,Dependents,tenure,PhoneService,MultipleLines,InternetService,OnlineSecurity,OnlineBackup,DeviceProtection,TechSupport,StreamingTV,StreamingMovies,Contract,PaperlessBilling,PaymentMethod,MonthlyCharges,TotalCharges,Churn
7590-VHVEG,Female,0,Yes,No,1,No,No phone service,DSL,No,Yes,No,No,No,No,Month-to-month,Yes,Electronic check,29.85,29.85,No
5575-GNVDE,Male,0,No,No,34,Yes,No,DSL,Yes,No,Yes,No,No,No,One year,No,Mailed check,56.95,1889.5,No
3668-QPYBK,Male,0,No,No,2,Yes,No,DSL,Yes,Yes,No,No,No,No,Month-to-month,Yes,Mailed check,53.85,108.15,Yes
7795-CFOCW,Male,0,No,No,45,No,No phone service,DSL,Yes,No,Yes,Yes,No,No,One year,No,Bank transfer (automatic),42.3,1840.75,No
9237-HQITU,Female,0,No,No,2,Yes,No,Fiber optic,No,No,No,No,No,No,Month-to-month,Yes,Electronic check,70.7,151.65,Yes
9305-CDSKC,Female,0,No,No,8,Yes,Yes,Fiber optic,No,No,Yes,No,Yes,Yes,Month-to-month,Yes,Electronic check,99.65,820.5,Yes
1452-KIOVK,Male,0,No,Yes,22,Yes,Yes,Fiber optic,No,Yes,No,No,Yes,No,Month-to-month,Yes,Credit card (automatic),89.1,1949.4,No
6713-OKOMC,Female,0,No,No,10,No,No phone service,DSL,Yes,No,No,No,No,No,Month-to-month,No,Mailed check,29.75,301.9,No
7892-POOKP,Female,0,Yes,No,28,Yes,Yes,Fiber optic,No,No,Yes,Yes,Yes,Yes,Month-to-month,Yes,Electronic check,104.8,3046.05,Yes
6388-TABGU,Male,0,No,Yes,62,Yes,No,DSL,Yes,Yes,No,No,No,No,One year,No,Bank transfer (automatic),56.15,3487.95,No


In [0]:
df.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: integer (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- TotalCharges: string (nullable = true)
 |-- Churn: string (nullable = true)



In [0]:
# View 10 random rows for important columns
df.select("customerID", "gender", "SeniorCitizen", "tenure", "MonthlyCharges", "TotalCharges", "Churn").display()


customerID,gender,SeniorCitizen,tenure,MonthlyCharges,TotalCharges,Churn
7590-VHVEG,Female,0,1,29.85,29.85,No
5575-GNVDE,Male,0,34,56.95,1889.5,No
3668-QPYBK,Male,0,2,53.85,108.15,Yes
7795-CFOCW,Male,0,45,42.3,1840.75,No
9237-HQITU,Female,0,2,70.7,151.65,Yes
9305-CDSKC,Female,0,8,99.65,820.5,Yes
1452-KIOVK,Male,0,22,89.1,1949.4,No
6713-OKOMC,Female,0,10,29.75,301.9,No
7892-POOKP,Female,0,28,104.8,3046.05,Yes
6388-TABGU,Male,0,62,56.15,3487.95,No


In [0]:
# Check for nulls or empty strings
from pyspark.sql.functions import col, count, when
df = df.withColumn("TotalCharges", col("TotalCharges").cast("float"))
df.select([count(when(col(c).isNull() | (col(c) == ''), c)).alias(c) for c in df.columns]).display()

customerID,gender,SeniorCitizen,Partner,Dependents,tenure,PhoneService,MultipleLines,InternetService,OnlineSecurity,OnlineBackup,DeviceProtection,TechSupport,StreamingTV,StreamingMovies,Contract,PaperlessBilling,PaymentMethod,MonthlyCharges,TotalCharges,Churn
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,11,0


In [0]:
# Register as temporary view
df.createOrReplaceTempView("churn_data")

# Run SQL to count total rows and churn breakdown
spark.sql("""
    SELECT 
        COUNT(*) AS total_customers,
        SUM(CASE WHEN Churn = 'Yes' THEN 1 ELSE 0 END) AS churned_customers,
        ROUND(100 * SUM(CASE WHEN Churn = 'Yes' THEN 1 ELSE 0 END) / COUNT(*), 2) AS churn_rate_percentage
    FROM churn_data
""").display()


total_customers,churned_customers,churn_rate_percentage
7043,1869,26.54


In [0]:
# Drop rows where TotalCharges is null (after casting it to float)
df = df.filter(df.TotalCharges.isNotNull())

In [0]:
# List all columns with string (categorical) types
string_cols = [f.name for f in df.schema.fields if f.dataType.simpleString() == 'string']
print("Categorical columns:", string_cols)

Categorical columns: ['customerID', 'gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod', 'Churn']


In [0]:
from pyspark.ml.feature import StringIndexer

# Index all string columns (including Churn)
indexers = [StringIndexer(inputCol=col, outputCol=col+"_indexed") for col in string_cols]

from pyspark.ml import Pipeline
pipeline = Pipeline(stages=indexers)
df_indexed = pipeline.fit(df).transform(df)

In [0]:
df_cleaned = df_indexed.drop(*string_cols)
df_cleaned.printSchema()


root
 |-- SeniorCitizen: integer (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- TotalCharges: float (nullable = true)
 |-- customerID_indexed: double (nullable = false)
 |-- gender_indexed: double (nullable = false)
 |-- Partner_indexed: double (nullable = false)
 |-- Dependents_indexed: double (nullable = false)
 |-- PhoneService_indexed: double (nullable = false)
 |-- MultipleLines_indexed: double (nullable = false)
 |-- InternetService_indexed: double (nullable = false)
 |-- OnlineSecurity_indexed: double (nullable = false)
 |-- OnlineBackup_indexed: double (nullable = false)
 |-- DeviceProtection_indexed: double (nullable = false)
 |-- TechSupport_indexed: double (nullable = false)
 |-- StreamingTV_indexed: double (nullable = false)
 |-- StreamingMovies_indexed: double (nullable = false)
 |-- Contract_indexed: double (nullable = false)
 |-- PaperlessBilling_indexed: double (nullable = false)
 |-- PaymentMethod_indexed: do

In [0]:
# Clean up features
excluded_cols = ['customerID', 'customerID_indexed', 'Churn_indexed']
feature_cols = [col for col in df_cleaned.columns if col not in excluded_cols]

In [0]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml import Pipeline

# Step 1: Assemble features
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")

# Step 2: Initialize Random Forest
rf = RandomForestClassifier(labelCol="Churn_indexed", featuresCol="features", numTrees=100)

# Step 3: Create Pipeline
pipeline = Pipeline(stages=[assembler, rf])

In [0]:
# Split into training and testing sets
train_data, test_data = df_cleaned.randomSplit([0.8, 0.2], seed=42)

In [0]:
# Train the pipeline model
pipeline_model = pipeline.fit(train_data)

In [0]:
# Predict on test data
rf_predictions = pipeline_model.transform(test_data)
rf_predictions.select("Churn_indexed", "prediction", "probability").show(10, truncate=False)

+-------------+----------+----------------------------------------+
|Churn_indexed|prediction|probability                             |
+-------------+----------+----------------------------------------+
|1.0          |0.0       |[0.7621700174072787,0.23782998259272123]|
|1.0          |0.0       |[0.7621700174072787,0.23782998259272123]|
|0.0          |0.0       |[0.7661273260979605,0.23387267390203947]|
|0.0          |0.0       |[0.7448443908495276,0.2551556091504724] |
|1.0          |0.0       |[0.7621700174072787,0.23782998259272123]|
|0.0          |0.0       |[0.7621700174072787,0.23782998259272123]|
|0.0          |0.0       |[0.7661273260979605,0.23387267390203947]|
|1.0          |0.0       |[0.7661273260979605,0.23387267390203947]|
|0.0          |0.0       |[0.7661273260979605,0.23387267390203947]|
|0.0          |0.0       |[0.8914649245673393,0.10853507543266079]|
+-------------+----------+----------------------------------------+
only showing top 10 rows



In [0]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

evaluator = BinaryClassificationEvaluator(labelCol="Churn_indexed")
rf_auc = evaluator.evaluate(rf_predictions)
print(f"Random Forest ROC AUC: {rf_auc:.4f}")

Random Forest ROC AUC: 0.8312


In [0]:
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType

# Step 1: Extract churn probability from prediction vector (index 1 = 'Yes')
get_churn_prob = udf(lambda x: float(x[1]), FloatType())
export_df = rf_predictions.withColumn("churn_probability", get_churn_prob(rf_predictions.probability))

# Step 2: Select only relevant columns
final_export = export_df.select(
    "customerID_indexed",
    "Churn_indexed",
    "prediction",
    "churn_probability"
)


final_export.display()

customerID_indexed,Churn_indexed,prediction,churn_probability
6549.0,1.0,0.0,0.23782998
4328.0,1.0,0.0,0.23782998
6628.0,0.0,0.0,0.23387267
2956.0,0.0,0.0,0.25515562
3255.0,1.0,0.0,0.23782998
5178.0,0.0,0.0,0.23782998
5942.0,0.0,0.0,0.23387267
698.0,1.0,0.0,0.23387267
2658.0,0.0,0.0,0.23387267
3939.0,0.0,0.0,0.10853507


In [0]:
# Convert Spark DataFrame to CSV as string rows
csv_rows = final_export.toPandas().to_csv(index=False)

# Save to DBFS using dbutils
dbutils.fs.put("/FileStore/churn_predictions.csv", csv_rows, overwrite=True)


Wrote 34583 bytes.
Out[25]: True

In [0]:
df = spark.read.csv("/FileStore/churn_predictions.csv", header=True, inferSchema=True)

# Display first few rows
df.display()


customerID_indexed,Churn_indexed,prediction,churn_probability
6549.0,1.0,0.0,0.23782998
4328.0,1.0,0.0,0.23782998
6628.0,0.0,0.0,0.23387267
2956.0,0.0,0.0,0.25515562
3255.0,1.0,0.0,0.23782998
5178.0,0.0,0.0,0.23782998
5942.0,0.0,0.0,0.23387267
698.0,1.0,0.0,0.23387267
2658.0,0.0,0.0,0.23387267
3939.0,0.0,0.0,0.10853507
