
# Customer Churn Prediction — PySpark MLlib (End-to-End)

**What this notebook does**  
1) Load `/mnt/data/churn.csv`  
2) Clean & prepare data (StringIndexer + OneHotEncoder + VectorAssembler)  
3) Train/test split (70/30)  
4) Train **Logistic Regression, Decision Tree, Random Forest**  
5) Evaluate with **AUC, Accuracy, weighted Precision, weighted Recall** + **Confusion Matrix**  
6) **Cross-Validation** on LR (light grid)  
7) **Feature Importances** (DT/RF)  
8) Save & load the best pipeline

> Tested on PySpark 3.x. If PySpark is not installed, install it with `pip install pyspark`.


In [None]:

# -----------------------------
# Setup: SparkSession & Imports
# -----------------------------
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# Location of the input CSV; point this at your uploaded churn file.
data_path = "/mnt/data/churn.csv"

# Reuse the active SparkSession when one exists, otherwise start a new one.
# If the driver needs more memory, chain
#   .config("spark.driver.memory", "4g")
# onto the builder before getOrCreate() is called.
session_builder = SparkSession.builder.appName("Churn-MLlib-Notebook")
spark = session_builder.getOrCreate()


In [None]:

# -----------------------------
# Load & Quick EDA
# -----------------------------
# Read the CSV with a header row and let Spark infer the column types.
df_raw = (spark.read
          .option("header", True)
          .option("inferSchema", True)
          .csv(data_path))

print("Schema:")
df_raw.printSchema()

print("Sample rows:")
df_raw.show(10, truncate=False)

# Clean the total-charges column (common issue: blank strings).
# The column is located ignoring case and underscores so that both
# "total_charges" and "TotalCharges" are handled; an exact-match check would
# silently skip the cleanup and leave the column as a string.
total_charges_col = next(
    (c for c in df_raw.columns if c.lower().replace("_", "") == "totalcharges"),
    None,
)
if total_charges_col is None:
    df = df_raw
else:
    raw_charges = F.col(total_charges_col)
    # Blank / whitespace-only values become NULL, everything else is cast to double.
    df = df_raw.withColumn(
        total_charges_col,
        F.when(F.length(F.trim(raw_charges)) == 0, None)
         .otherwise(raw_charges)
         .cast(DoubleType()),
    )

# Standardize whitespace in string columns
for c, t in df.dtypes:
    if t == "string":
        df = df.withColumn(c, F.trim(F.col(c)))

# Basic class balance if a 'churn' column exists (case-insensitive match)
churn_matches = [c for c in df.columns if c.lower() == "churn"]
if churn_matches:
    churn_col = churn_matches[0]
    print("Churn counts:")
    df.groupBy(churn_col).count().orderBy(F.desc("count")).show()
else:
    print("No 'churn' column detected (case-insensitive).")


In [None]:

# -----------------------------
# Feature Engineering
# -----------------------------
# Define label column
# Resolve the label column (case-insensitive match on 'churn').
label_col = [c for c in df.columns if c.lower() == "churn"]
assert len(label_col) == 1, "Label column 'churn' not found or ambiguous."
label_col = label_col[0]

# Identify numeric and categorical columns from the inferred dtypes.
# You can pin exact columns if you want stricter control.
numeric_cols = [c for c, t in df.dtypes if t in ("int", "bigint", "double", "float")]
numeric_cols = [c for c in numeric_cols if c != label_col]

categorical_cols = [c for c, t in df.dtypes if t == "string"]
categorical_cols = [c for c in categorical_cols if c != label_col]

# Drop rows with nulls in any feature or in the label (fast, robust).
critical = categorical_cols + numeric_cols + [label_col]
df = df.dropna(subset=critical)

# Label indexer.
# handleInvalid must NOT be "keep" here: "keep" appends an extra "__unknown"
# bucket to the label metadata, so downstream classifiers would infer one class
# more than the data actually has (3 instead of 2 for a binary churn label).
# Null labels are already dropped above, so "error" only fires on a label value
# that was never seen during fit, which should be surfaced rather than hidden.
# With the default ordering (frequencyDesc) the most frequent label maps to 0.0.
label_indexer = StringIndexer(inputCol=label_col, outputCol="label", handleInvalid="error")

# Index + OneHot for categoricals; unseen feature categories are kept in an
# extra bucket so scoring does not fail on new values.
indexers = [StringIndexer(inputCol=c, outputCol=f"{c}_idx", handleInvalid="keep") for c in categorical_cols]
encoder  = OneHotEncoder(inputCols=[f"{c}_idx" for c in categorical_cols],
                         outputCols=[f"{c}_oh"  for c in categorical_cols])

# Assemble one-hot vectors followed by the raw numeric columns into "features".
assembler = VectorAssembler(
    inputCols=[f"{c}_oh" for c in categorical_cols] + numeric_cols,
    outputCol="features"
)

# Train/test split (70/30, fixed seed for reproducibility)
train, test = df.randomSplit([0.7, 0.3], seed=42)

print("Train:", train.count(), "Test:", test.count())
print("Numeric columns:", numeric_cols)
print("Categorical columns:", categorical_cols)


In [None]:

# -----------------------------
# Models: LR, DT, RF
# -----------------------------
# Baseline estimators. The tree-based models get an explicit seed: PySpark's
# default seed is derived from a Python string hash, which changes between
# interpreter processes, so without it results are not reproducible across runs.
lr = LogisticRegression(featuresCol="features", labelCol="label", maxIter=50, regParam=0.0, elasticNetParam=0.0)
dt = DecisionTreeClassifier(featuresCol="features", labelCol="label", maxDepth=8, minInstancesPerNode=10, seed=42)
rf = RandomForestClassifier(featuresCol="features", labelCol="label", numTrees=120, maxDepth=12,
                            subsamplingRate=0.8, seed=42)

# Shared preprocessing: label indexing, categorical indexing, one-hot, assembly.
base_stages = [label_indexer] + indexers + [encoder, assembler]

# One pipeline per estimator so each fitted model carries its own preprocessing.
lr_pipe = Pipeline(stages=base_stages + [lr])
dt_pipe = Pipeline(stages=base_stages + [dt])
rf_pipe = Pipeline(stages=base_stages + [rf])

lr_model = lr_pipe.fit(train)
dt_model = dt_pipe.fit(train)
rf_model = rf_pipe.fit(train)

print("Models trained.")


In [None]:

# -----------------------------
# Evaluation
# -----------------------------
# Area under the ROC curve, computed from the raw prediction column.
bce = BinaryClassificationEvaluator(metricName="areaUnderROC", labelCol="label")

# Accuracy plus weighted precision / recall, all sharing the same label column.
macc, mpre, mrec = (
    MulticlassClassificationEvaluator(metricName=metric, labelCol="label")
    for metric in ("accuracy", "weightedPrecision", "weightedRecall")
)

def evaluate(model, name):
    """Score a fitted pipeline on the notebook's `test` split and print a report.

    Args:
        model: Fitted pipeline model; its `transform` output must contain the
            `label`, `prediction` and raw-prediction columns the evaluators read.
        name: Display name used in the printed summary line.

    Returns:
        dict with the cached predictions DataFrame under "preds" and the
        metrics under "auc", "acc", "pre" (weighted precision) and
        "rec" (weighted recall).
    """
    # Cache once: the same predictions feed four evaluators and the confusion matrix.
    preds = model.transform(test).cache()

    scorers = {"auc": bce, "acc": macc, "pre": mpre, "rec": mrec}
    scores = {key: scorer.evaluate(preds) for key, scorer in scorers.items()}

    print(
        f"\n{name} — AUC: {scores['auc']:.4f} | Acc: {scores['acc']:.4f} "
        f"| Prec: {scores['pre']:.4f} | Recall: {scores['rec']:.4f}"
    )

    # Confusion matrix as (label, prediction) pair counts.
    print("Confusion Matrix [label, prediction, count]:")
    preds.groupBy("label", "prediction").count().orderBy("label", "prediction").show()

    return {"preds": preds, **scores}

from collections import OrderedDict

# Score each fitted pipeline on the held-out test split.
lr_metrics = evaluate(lr_model, "LogisticRegression")
dt_metrics = evaluate(dt_model, "DecisionTree")
rf_metrics = evaluate(rf_model, "RandomForest")

# Leaderboard: model tag -> AUC, ordered from best to worst.
auc_by_model = [
    ("LR", lr_metrics["auc"]),
    ("DT", dt_metrics["auc"]),
    ("RF", rf_metrics["auc"]),
]
auc_by_model.sort(key=lambda entry: entry[1], reverse=True)
leaderboard = OrderedDict(auc_by_model)
print("\nModel AUC Leaderboard:", leaderboard)


In [None]:

# -----------------------------
# Cross-Validation (LR)
# -----------------------------
# Fresh LR estimator so the grid search does not touch the baseline `lr`.
lr_cv = LogisticRegression(featuresCol="features", labelCol="label", maxIter=60)
lr_cv_pipe = Pipeline(stages=base_stages + [lr_cv])

# Light 3x3 grid over regularization strength and L1/L2 mixing.
paramGrid = (ParamGridBuilder()
             .addGrid(lr_cv.regParam, [0.0, 0.01, 0.1])
             .addGrid(lr_cv.elasticNetParam, [0.0, 0.5, 1.0])
             .build())

# Explicit seed so the fold assignment (and therefore the selected model) is
# reproducible across runs, consistent with the seeded train/test split.
cv = CrossValidator(estimator=lr_cv_pipe,
                    estimatorParamMaps=paramGrid,
                    evaluator=bce,
                    numFolds=3,
                    parallelism=2,
                    seed=42)

cv_model = cv.fit(train)
cv_metrics = evaluate(cv_model.bestModel, "LR (CV best)")

# Read the winning hyper-parameters through the public getters on the fitted
# LogisticRegressionModel (available in PySpark 3.x) instead of reaching into
# the private `_java_obj` handle.
best_lr = cv_model.bestModel.stages[-1]
print("Best LR Params:")
print("regParam:", best_lr.getRegParam())
print("elasticNetParam:", best_lr.getElasticNetParam())


In [None]:

# -----------------------------
# Feature Importances (DT/RF)
# -----------------------------
def print_top_importances(fitted_pipe, top_k=15, sample_df=None):
    """Print the `top_k` largest feature importances of a fitted tree pipeline.

    The importance vector has one entry per slot of the assembled feature
    vector, which is longer than the assembler's input-column list once
    categorical columns are one-hot encoded. Slot names are therefore read from
    the ML attribute metadata of the assembled vector column rather than from
    `assembler.getInputCols()`, which would mislabel slots or raise IndexError.

    Args:
        fitted_pipe: Fitted pipeline model whose last stage may expose
            `featureImportances`; nothing is printed if it does not.
        top_k: Number of features to print, highest importance first.
        sample_df: DataFrame used only to derive the output schema of the
            pipeline. Defaults to the notebook's `train` split.
    """
    last = fitted_pipe.stages[-1]
    if not hasattr(last, "featureImportances"):
        return

    importances = last.featureImportances.toArray()

    # Only the schema of the transformed frame is needed to read the metadata.
    schema_source = train if sample_df is None else sample_df
    features_field = fitted_pipe.transform(schema_source).schema[assembler.getOutputCol()]
    attr_groups = features_field.metadata.get("ml_attr", {}).get("attrs", {})

    # Map vector slot index -> attribute name across all attribute groups.
    slot_names = {
        attr["idx"]: attr["name"]
        for group in attr_groups.values()
        for attr in group
        if "name" in attr
    }

    ranked = sorted(enumerate(importances), key=lambda x: x[1], reverse=True)[:top_k]
    print("\nTop feature importances:")
    for idx, val in ranked:
        # Fall back to a positional name when the slot carries no attribute name.
        feature_name = slot_names.get(idx, f"feature_{idx}")
        print(f"{feature_name:<30} {val:.6f}")

print_top_importances(dt_model)
print_top_importances(rf_model)


In [None]:

# -----------------------------
# Save & Load Best Model
# -----------------------------
# Candidate (tag, fitted pipeline, test AUC) triples, in evaluation order.
cand_models = list(zip(
    ("LR", "DT", "RF", "LR_CV"),
    (lr_model, dt_model, rf_model, cv_model.bestModel),
    (metrics["auc"] for metrics in (lr_metrics, dt_metrics, rf_metrics, cv_metrics)),
))

# Highest AUC wins; on a tie the earliest candidate in the list is kept.
best_name, best_pipe, best_auc = max(cand_models, key=lambda cand: cand[2])
print(f"\nSelected final model: {best_name} (AUC={best_auc:.4f})")

# Persist the winning pipeline, replacing any previous save at the same path.
save_path = f"/mnt/data/models/{best_name}_pipeline"
best_pipe.write().overwrite().save(save_path)
print(f"Saved pipeline to: {save_path}")

# Round-trip check: reload from disk and re-score on the test split.
loaded = PipelineModel.load(save_path)
reload_evaluator = BinaryClassificationEvaluator(metricName="areaUnderROC", labelCol="label")
loaded_auc = reload_evaluator.evaluate(loaded.transform(test))
print(f"Loaded AUC: {loaded_auc:.4f}")


In [None]:

# Optional: stop Spark when done
# spark.stop()
