In [0]:
# MACHINE LEARNING - Churn Prediction
# Purpose: Predict which customers are likely to churn

from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml import Pipeline
from pyspark.sql import functions as F
from pyspark.sql.functions import percent_rank
from pyspark.sql.window import Window
from pyspark.ml.evaluation import BinaryClassificationEvaluator

print("="*80)
print("Machine Learning: Churn Prediction Model")
print("="*80)

# Load the per-customer aggregate table (`spark` is the notebook's session).
customer_data = spark.table("workspace.default.gold_customer_metrics")

# ============================================
# CHURN DEFINITION
# ============================================

# A customer is labelled churned when recency_days is strictly greater than
# this threshold. For e-commerce, 180 days (6 months) is a realistic window.
churn_threshold = 180

print(f"\n📊 Churn Definition: Customer inactive for more than {churn_threshold} days")

# Target variable: 1 = churned, 0 = active.
# NOTE: a NULL recency_days falls through to otherwise(0), so such customers
# are counted as active in the statistics printed below.
customer_data = customer_data.withColumn(
    "is_churned",
    F.when(F.col("recency_days") > churn_threshold, 1).otherwise(0)
)

# Verify churn distribution
churn_distribution = customer_data.groupBy("is_churned").count()

print("\n📈 Churn Distribution:")
churn_distribution.show()

# A single aggregation yields both numbers (is_churned is 0/1, so its sum is
# the churned count) instead of two separate count() scans of the table.
label_totals = customer_data.agg(
    F.count("*").alias("total"),
    F.sum("is_churned").alias("churned"),
).first()
total_customers = label_totals["total"]
churn_count = label_totals["churned"] or 0  # sum() is NULL on an empty table
# Guard the empty-table case instead of raising ZeroDivisionError.
churn_rate = (churn_count / total_customers * 100) if total_customers else 0.0
print(f"   Churn Rate: {churn_rate:.2f}%")
print(f"   Active Customers: {total_customers - churn_count:,}")
print(f"   Churned Customers: {churn_count:,}")

# Feature engineering.
# `recency_days` is deliberately NOT a feature: the label is defined as
# recency_days > churn_threshold, so feeding it to the model lets the forest
# reproduce the label with a single split (target leakage). Keeping it would
# make the model useless for prediction and inflate the evaluated AUC.
# NOTE(review): confirm none of the remaining columns is itself derived from
# the last-purchase date (e.g. a lifetime measured up to the snapshot date).
features_for_model = [
    "frequency",
    "monetary_value",
    "avg_order_value",
    "avg_delivery_days",
    "avg_delivery_delay",
    "avg_review_score",
    "customer_lifetime_days",
    "total_items_purchased",
]

# Prepare training data: cast the label and features to double (Spark ML
# works on numeric columns) and drop rows with missing values.
# Rows with a NULL recency_days are excluded explicitly. Their is_churned
# value only came from the otherwise(0) fallback, not from real data.
ml_data = (
    customer_data
    .filter(F.col("recency_days").isNotNull())
    .select(
        "customer_unique_id",
        F.col("is_churned").cast("double").alias("is_churned"),
        *[F.col(c).cast("double").alias(c) for c in features_for_model],
    )
    .na.drop()
)

# Split data 80/20 with a fixed seed so the split is reproducible between runs.
train_data, test_data = ml_data.randomSplit([0.8, 0.2], seed=42)

print("\n📚 Dataset Split:")
print(f"   Training set: {train_data.count():,} customers")
print(f"   Test set: {test_data.count():,} customers")

# Build pipeline: assemble the numeric columns into one vector, scale it,
# then classify.
# NOTE: a random forest splits on per-feature thresholds, so this scaling does
# not change what the model can learn; it only costs an extra pass at fit time.
assembler = VectorAssembler(inputCols=features_for_model, outputCol="features_raw")
scaler = StandardScaler(inputCol="features_raw", outputCol="features")

# Random Forest: small, shallow forest with a fixed seed for reproducibility.
rf = RandomForestClassifier(
    labelCol="is_churned",
    featuresCol="features",
    numTrees=10,
    maxDepth=5,
    seed=42
)

pipeline = Pipeline(stages=[assembler, scaler, rf])

# Train on the training split only.
print("\n🔄 Training Random Forest model...")
model = pipeline.fit(train_data)
print("   ✅ Model trained successfully!")

# Score ALL customers. Training rows are included, so their scores are
# in-sample; held-out metrics are computed on test_data further below.
print("\n🔮 Making predictions on all customers...")
predictions = model.transform(ml_data)

# Row-level UDF that turns the classifier's probability vector into a scalar.
@F.udf(returnType="double")
def extract_probability(probability):
    """Return the churn-class (label 1) probability as a Python float.

    Intended for the classifier's ``probability`` vector column, where index 1
    is the positive class. A missing (None) vector is mapped to 0.0.
    """
    if probability is None:
        return 0.0
    return float(probability[1])

# ============================================
# PREDICTED CHURN
# ============================================

# Scalar P(churn) from the probability vector, then a hard 0/1 prediction
# with the conventional 0.5 cut-off (probability >= 0.5 -> 1, else 0).
predictions = (
    predictions
    .withColumn("churn_probability", extract_probability(F.col("probability")))
    .withColumn(
        "predicted_churn",
        F.when(F.col("churn_probability") >= 0.5, 1).otherwise(0).cast("int"),
    )
)

# ============================================
# RISK CATEGORIES
# ============================================

# Relative ranking: percent_rank() over probability, highest first, gives 0.0
# to the most at-risk customer and up to 1.0 for the least. Tied probabilities
# share a percentile, so a tie group always lands in the same band.
# NOTE(review): the window has no partitionBy, so Spark evaluates it in a
# single partition; consider approxQuantile-based cut-offs if the table grows.
window_spec = Window.orderBy(F.col("churn_probability").desc())
predictions = predictions.withColumn(
    "churn_percentile", percent_rank().over(window_spec)
)

# Percentile bands: <= 0.20 -> High Risk (top 20%), <= 0.50 -> Medium Risk
# (next 30%), anything else -> Low Risk (bottom 50%).
predictions = predictions.withColumn(
    "churn_risk_category",
    F.when(F.col("churn_percentile") <= 0.20, "High Risk")
    .when(F.col("churn_percentile") <= 0.50, "Medium Risk")
    .otherwise("Low Risk"),
)

# Keep only the output columns. is_churned goes back to int; it was cast to
# double earlier for Spark ML.
predictions_final = predictions.select(
    "customer_unique_id",
    F.col("is_churned").cast("int").alias("is_churned"),
    "churn_probability",
    "churn_risk_category",
    "predicted_churn"
)

# Persist as a Delta table, replacing both data and schema of any prior run.
predictions_final.write \
    .format("delta") \
    .mode("overwrite") \
    .option("overwriteSchema", "true") \
    .saveAsTable("workspace.default.gold_customer_churn_predictions")

print("   ✅ Created gold_customer_churn_predictions")

# ============================================
# METRICS AND ANALYSIS
# ============================================

# Held-out evaluation: score only the test split the model was not fit on.
print("\n📈 Model Evaluation:")
test_predictions = model.transform(test_data)

# AUC-ROC computed from the classifier's rawPrediction column (the
# evaluator's default input column).
evaluator = BinaryClassificationEvaluator(
    labelCol="is_churned",
    metricName="areaUnderROC"
)
auc = evaluator.evaluate(test_predictions)
print(f"   AUC-ROC Score: {auc:.3f}")

# Show risk distribution: size and mean probability of each band.
print("\n📊 Churn Risk Distribution:")
risk_dist = predictions_final.groupBy("churn_risk_category") \
    .agg(
        F.count("*").alias("customer_count"),
        F.avg("churn_probability").alias("avg_probability")
    ) \
    .orderBy(F.desc("avg_probability"))

risk_dist.show()

# Churn rate by predicted_churn
print("\n🎯 Predicted Churn Summary:")
pred_summary = predictions_final.groupBy("predicted_churn").count()
pred_summary.show()

# predicted_churn is 0/1, so one aggregation yields both the positive count
# and the total. This replaces two count() jobs that each re-ran the full
# scoring lineage.
pred_totals = predictions_final.agg(
    F.count("*").alias("total"),
    F.sum("predicted_churn").alias("churned"),
).first()
predicted_churn_count = pred_totals["churned"] or 0  # sum() is NULL when empty
predicted_total = pred_totals["total"]
predicted_churn_rate = (
    (predicted_churn_count / predicted_total * 100) if predicted_total else 0.0
)
print(f"   Predicted Churn Rate: {predicted_churn_rate:.2f}%")

# NOTE: churn_rate covers every row of gold_customer_metrics, while
# predicted_churn_rate covers only the rows kept for modelling (after null
# filtering), so the two rates can use different denominators.
print("\n" + "="*80)
print("🎉 Churn Prediction Complete!")
print("="*80)
print("\nKey Takeaways:")
print(f"  • Definition: Churn = more than {churn_threshold} days inactive")
print(f"  • Actual Churn Rate: {churn_rate:.2f}%")
print(f"  • Predicted Churn Rate: {predicted_churn_rate:.2f}%")
print(f"  • Model AUC-ROC: {auc:.3f}")
print("="*80)