In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, trim, lower, regexp_replace, mean, stddev
from pyspark.sql.types import IntegerType, DoubleType, StringType

In [0]:
# Obtain the Spark entry point for this notebook.
# getOrCreate() hands back an existing session when one is already active.
session_builder = SparkSession.builder.appName("ChurnDataCleaning")
spark = session_builder.getOrCreate()

In [0]:
# Read the raw CSV: the first row supplies the column names and Spark infers
# each column's type from the data.
raw_reader = spark.read.option("header", True).option("inferSchema", True)
df = raw_reader.csv("synthetic_churn_uncleaned.csv")

In [0]:
# --------------------
# 1. Remove duplicates
# --------------------
# Rows are compared across every column, so only exact full-row copies are dropped.
df = df.distinct()

In [0]:
# --------------------
# 2. Standardize text columns
# --------------------
# Collect the name of every column whose Spark type is StringType.
string_cols = []
for field in df.schema.fields:
    if isinstance(field.dataType, StringType):
        string_cols.append(field.name)

In [0]:

# Normalise every string column: convert to lowercase, then strip
# leading/trailing spaces. Each column is overwritten in place (same name).
for col_name in string_cols:
    lowered = lower(col(col_name))
    cleaned = trim(lowered)
    df = df.withColumn(col_name, cleaned)

In [0]:
# Collapse gender abbreviations onto their canonical labels.
# Any value outside the two recognised groups is passed through unchanged.
gender = col("Gender")
gender_canonical = (
    when(gender.isin("male", "m"), "male")
    .when(gender.isin("female", "f"), "female")
    .otherwise(gender)
)
df = df.withColumn("Gender", gender_canonical)

In [0]:
# Collapse spelling/separator variants of "fiber optic" onto one canonical label.
# The previous call replaced "fiber optic" with the identical string, so it
# changed nothing. The pattern below matches e.g. "fibre optic", "fiber-optic",
# "fiber_optic", "fiberoptic" and "fiber optics".
# NOTE(review): the set of variants is an assumption — confirm it against the
# distinct values of "Internet Type" in the raw file.
FIBER_OPTIC_VARIANTS = r"fib(?:er|re)[\s_\-]*optics?"
df = df.withColumn(
    "Internet Type",
    regexp_replace(col("Internet Type"), FIBER_OPTIC_VARIANTS, "fiber optic"),
)

In [0]:
# --------------------
# 3. Handle missing values
# --------------------
# Columns typed IntegerType/DoubleType are filled with their approximate median;
# every other column is filled with its most frequent non-null value (mode).
numeric_cols = [f.name for f in df.schema.fields if isinstance(f.dataType, (IntegerType, DoubleType))]
categorical_cols = [c for c in df.columns if c not in numeric_cols]

# Replacement value per column, applied in a single na.fill at the end.
fill_values = {}

# Numeric: one approxQuantile call computes every median in a single pass.
# A column with no non-null values yields an empty list, so it is skipped
# instead of being indexed (which would raise IndexError).
if numeric_cols:
    quantiles = df.approxQuantile(numeric_cols, [0.5], 0.01)
    for col_name, quantile in zip(numeric_cols, quantiles):
        if quantile:
            fill_values[col_name] = quantile[0]

# Categorical: most frequent value, counted over non-null rows only.
# Counting nulls would let None win in a mostly-missing column, and None is
# not a legal fill value.
for col_name in categorical_cols:
    top_row = (
        df.where(col(col_name).isNotNull())
        .groupBy(col_name)
        .count()
        # Secondary sort key makes the choice deterministic when counts tie.
        .orderBy(col("count").desc(), col(col_name).asc())
        .first()
    )
    # top_row is None when the column has no non-null values at all.
    # na.fill only accepts bool/int/float/str replacements, so modes of any
    # other type (e.g. dates or timestamps) are left unfilled rather than
    # raising.
    if top_row is not None and isinstance(top_row[0], (bool, int, float, str)):
        fill_values[col_name] = top_row[0]

if fill_values:
    df = df.na.fill(fill_values)

In [0]:
# --------------------
# 4. Ensure correct data types
# --------------------
binary_cols = ["Married", "Paperless Billing", "Online Security",
               "Device Protection Plan", "Streaming TV",
               "Streaming Movies", "Unlimited Data", "Churn"]

# Casting text such as "yes"/"no" straight to IntegerType produces NULL for
# every row (or an error when ANSI mode is on), which would wipe out the flag.
# Recognised tokens are therefore mapped to 1/0 first; anything else falls
# back to the plain integer cast, so columns that are already numeric keep
# their values.
# NOTE(review): the token lists are an assumption about how the flags are
# spelled in the raw file — confirm against the distinct values.
binary_true_tokens = ["yes", "y", "true", "t", "1"]
binary_false_tokens = ["no", "n", "false", "f", "0"]

for col_name in binary_cols:
    # Compare on a lowercase, trimmed string view so the mapping works for
    # string, boolean and numeric source columns alike.
    token = lower(trim(col(col_name).cast(StringType())))
    df = df.withColumn(
        col_name,
        when(token.isin(binary_true_tokens), 1)
        .when(token.isin(binary_false_tokens), 0)
        .otherwise(col(col_name).cast(IntegerType()))
        .cast(IntegerType()),
    )

# Remaining numeric columns and the Spark type each one is cast to.
numeric_casts = {
    "Monthly Charge": DoubleType(),
    "Total Charges": DoubleType(),
    "Satisfaction Score": IntegerType(),
    "Number of Dependents": IntegerType(),
}
for col_name, target_type in numeric_casts.items():
    df = df.withColumn(col_name, col(col_name).cast(target_type))



In [0]:
# --------------------
# 5. Remove extreme outliers
# --------------------
# Keep only rows whose value lies strictly within Z_THRESHOLD standard
# deviations of the column mean. Columns are processed one after another, so
# each column's statistics are computed on the rows that survived the
# previous filters.
Z_THRESHOLD = 3

for col_name in ["Monthly Charge", "Total Charges", "Age", "Tenure in Months"]:
    stats = df.select(mean(col_name).alias("mean"), stddev(col_name).alias("std")).first()
    mean_val, std_val = stats["mean"], stats["std"]

    # Skip columns where a z-score is undefined: no non-null values (None),
    # too few rows for a sample standard deviation (None/NaN), or a constant
    # column (std == 0). Dividing by such a value makes the predicate NULL for
    # every row, which would silently empty the DataFrame.
    if mean_val is None or std_val is None or not std_val > 0:
        continue

    # |x - mean| / std < k  is equivalent to  mean - k*std < x < mean + k*std
    # for std > 0. Written as explicit bounds because pyspark Column has no
    # .abs() method: `column.abs` is treated as field extraction and calling
    # it raises "TypeError: 'Column' object is not callable".
    low_cut = mean_val - Z_THRESHOLD * std_val
    high_cut = mean_val + Z_THRESHOLD * std_val
    # Rows where the column is NULL fail both comparisons and are dropped.
    df = df.filter((col(col_name) > low_cut) & (col(col_name) < high_cut))



In [0]:
# --------------------
# 6. Derive useful non-ML columns
# --------------------
# "Senior Citizen": 1 when Age is 60 or more, otherwise 0
# (a NULL Age does not satisfy the condition, so it also yields 0).
is_senior = col("Age") >= 60
df = df.withColumn("Senior Citizen", when(is_senior, 1).otherwise(0))

# "Long Tenure Flag": 1 when Tenure in Months is 24 or more, otherwise 0.
has_long_tenure = col("Tenure in Months") >= 24
df = df.withColumn("Long Tenure Flag", when(has_long_tenure, 1).otherwise(0))

# --------------------
# 7. Final check
# --------------------
# Print the resulting schema, then the first 10 rows with untruncated values.
df.printSchema()
df.show(n=10, truncate=False)
