# AI-Driven Customer Churn Prediction
This notebook implements an end-to-end Databricks pipeline.



In [0]:
%sql
CREATE CATALOG IF NOT EXISTS ecommerce_ai;
USE CATALOG ecommerce_ai;

CREATE SCHEMA IF NOT EXISTS bronze;
CREATE SCHEMA IF NOT EXISTS silver;
CREATE SCHEMA IF NOT EXISTS gold;


com.databricks.backend.common.rpc.CommandCancelledException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$5(SequenceExecutionState.scala:139)
	at scala.Option.getOrElse(Option.scala:201)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:139)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:136)
	at scala.collection.immutable.Range.foreach(Range.scala:192)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:136)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:721)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:441)
	at scala.Option.getOrElse(Option.scala:201)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:441)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.can

##Bronze Layer – Raw Data Ingestion

In [0]:
dbutils.fs.ls("/Volumes/ecommerce_ai/bronze/raw_data/")



[FileInfo(path='dbfs:/Volumes/ecommerce_ai/bronze/raw_data/data_ecommerce_customer_churn.csv', name='data_ecommerce_customer_churn.csv', size=206873, modificationTime=1769422637000)]

In [0]:
raw_df = spark.read.csv(
    "/Volumes/ecommerce_ai/bronze/raw_data/data_ecommerce_customer_churn.csv",
    header=True,
    inferSchema=True
)

display(raw_df)


Tenure,WarehouseToHome,NumberOfDeviceRegistered,PreferedOrderCat,SatisfactionScore,MaritalStatus,NumberOfAddress,Complain,DaySinceLastOrder,CashbackAmount,Churn
15.0,29.0,4,Laptop & Accessory,3,Single,2,0,7.0,143.32,0
7.0,25.0,4,Mobile,1,Married,2,0,7.0,129.29,0
27.0,13.0,3,Laptop & Accessory,1,Married,5,0,7.0,168.54,0
20.0,25.0,4,Fashion,3,Divorced,7,0,,230.27,0
30.0,15.0,4,Others,4,Single,8,0,8.0,322.17,0
7.0,16.0,4,Mobile Phone,2,Divorced,2,0,11.0,152.81,0
1.0,15.0,6,Mobile Phone,5,Divorced,3,0,2.0,149.51,0
1.0,11.0,4,Mobile Phone,5,Single,3,0,1.0,154.73,1
11.0,12.0,4,Mobile Phone,3,Married,2,0,4.0,137.02,0
17.0,7.0,3,Laptop & Accessory,1,Married,5,1,2.0,157.43,0


In [0]:
raw_df.printSchema()


root
 |-- Tenure: double (nullable = true)
 |-- WarehouseToHome: double (nullable = true)
 |-- NumberOfDeviceRegistered: integer (nullable = true)
 |-- PreferedOrderCat: string (nullable = true)
 |-- SatisfactionScore: integer (nullable = true)
 |-- MaritalStatus: string (nullable = true)
 |-- NumberOfAddress: integer (nullable = true)
 |-- Complain: integer (nullable = true)
 |-- DaySinceLastOrder: double (nullable = true)
 |-- CashbackAmount: double (nullable = true)
 |-- Churn: integer (nullable = true)



In [0]:
raw_df.write \
  .format("delta") \
  .mode("overwrite") \
  .saveAsTable("ecommerce_ai.bronze.customer_churn_raw")


In [0]:
%sql
SELECT COUNT(*) AS row_count
FROM ecommerce_ai.bronze.customer_churn_raw;


row_count
3941


In [0]:
%sql
SELECT * 
FROM ecommerce_ai.bronze.customer_churn_raw
LIMIT 5;


Tenure,WarehouseToHome,NumberOfDeviceRegistered,PreferedOrderCat,SatisfactionScore,MaritalStatus,NumberOfAddress,Complain,DaySinceLastOrder,CashbackAmount,Churn
15.0,29.0,4,Laptop & Accessory,3,Single,2,0,7.0,143.32,0
7.0,25.0,4,Mobile,1,Married,2,0,7.0,129.29,0
27.0,13.0,3,Laptop & Accessory,1,Married,5,0,7.0,168.54,0
20.0,25.0,4,Fashion,3,Divorced,7,0,,230.27,0
30.0,15.0,4,Others,4,Single,8,0,8.0,322.17,0


In [0]:
bronze_df = spark.table("ecommerce_ai.bronze.customer_churn_raw")
display(bronze_df)


Tenure,WarehouseToHome,NumberOfDeviceRegistered,PreferedOrderCat,SatisfactionScore,MaritalStatus,NumberOfAddress,Complain,DaySinceLastOrder,CashbackAmount,Churn
15.0,29.0,4,Laptop & Accessory,3,Single,2,0,7.0,143.32,0
7.0,25.0,4,Mobile,1,Married,2,0,7.0,129.29,0
27.0,13.0,3,Laptop & Accessory,1,Married,5,0,7.0,168.54,0
20.0,25.0,4,Fashion,3,Divorced,7,0,,230.27,0
30.0,15.0,4,Others,4,Single,8,0,8.0,322.17,0
7.0,16.0,4,Mobile Phone,2,Divorced,2,0,11.0,152.81,0
1.0,15.0,6,Mobile Phone,5,Divorced,3,0,2.0,149.51,0
1.0,11.0,4,Mobile Phone,5,Single,3,0,1.0,154.73,1
11.0,12.0,4,Mobile Phone,3,Married,2,0,4.0,137.02,0
17.0,7.0,3,Laptop & Accessory,1,Married,5,1,2.0,157.43,0


In [0]:
from pyspark.sql.functions import col, sum

# Null check
bronze_df.select([
    sum(col(c).isNull().cast("int")).alias(c)
    for c in bronze_df.columns
]).show()


+------+---------------+------------------------+----------------+-----------------+-------------+---------------+--------+-----------------+--------------+-----+
|Tenure|WarehouseToHome|NumberOfDeviceRegistered|PreferedOrderCat|SatisfactionScore|MaritalStatus|NumberOfAddress|Complain|DaySinceLastOrder|CashbackAmount|Churn|
+------+---------------+------------------------+----------------+-----------------+-------------+---------------+--------+-----------------+--------------+-----+
|   194|            169|                       0|               0|                0|            0|              0|       0|              213|             0|    0|
+------+---------------+------------------------+----------------+-----------------+-------------+---------------+--------+-----------------+--------------+-----+



##Silver Layer – Data Cleaning & Validation

In [0]:
from pyspark.sql.functions import col, when
from pyspark.sql.types import IntegerType

silver_df = bronze_df \
    .filter(col("Churn").isNotNull()) \
    .withColumn("Churn", col("Churn").cast(IntegerType()))

# Fill numeric nulls
numeric_cols = [
    "Tenure", "WarehouseToHome", "NumberOfDeviceRegistered",
    "SatisfactionScore", "NumberOfAddress",
    "DaySinceLastOrder", "CashbackAmount"
]

for c in numeric_cols:
    median_value = silver_df.approxQuantile(c, [0.5], 0.01)[0]
    silver_df = silver_df.fillna({c: median_value})

# Fill categorical nulls
silver_df = silver_df.fillna({
    "PreferedOrderCat": "Unknown",
    "MaritalStatus": "Unknown"
})


In [0]:
silver_df.write \
  .format("delta") \
  .mode("overwrite") \
  .saveAsTable("ecommerce_ai.silver.customer_churn_clean")


In [0]:
%sql
SELECT COUNT(*) FROM ecommerce_ai.silver.customer_churn_clean;


COUNT(*)
3941


In [0]:
silver_df = spark.table("ecommerce_ai.silver.customer_churn_clean")
display(silver_df)


Tenure,WarehouseToHome,NumberOfDeviceRegistered,PreferedOrderCat,SatisfactionScore,MaritalStatus,NumberOfAddress,Complain,DaySinceLastOrder,CashbackAmount,Churn
15.0,29.0,4,Laptop & Accessory,3,Single,2,0,7.0,143.32,0
7.0,25.0,4,Mobile,1,Married,2,0,7.0,129.29,0
27.0,13.0,3,Laptop & Accessory,1,Married,5,0,7.0,168.54,0
20.0,25.0,4,Fashion,3,Divorced,7,0,3.0,230.27,0
30.0,15.0,4,Others,4,Single,8,0,8.0,322.17,0
7.0,16.0,4,Mobile Phone,2,Divorced,2,0,11.0,152.81,0
1.0,15.0,6,Mobile Phone,5,Divorced,3,0,2.0,149.51,0
1.0,11.0,4,Mobile Phone,5,Single,3,0,1.0,154.73,1
11.0,12.0,4,Mobile Phone,3,Married,2,0,4.0,137.02,0
17.0,7.0,3,Laptop & Accessory,1,Married,5,1,2.0,157.43,0


##Gold Layer – Feature Engineering

In [0]:
from pyspark.sql.functions import when

gold_df = silver_df.withColumn(
    "TenureBucket",
    when(silver_df.Tenure < 6, "New")
    .when(silver_df.Tenure < 24, "Mid")
    .otherwise("Long")
)


In [0]:
gold_df = gold_df.withColumn(
    "HighRecencyRisk",
    when(gold_df.DaySinceLastOrder > 90, 1).otherwise(0)
)


In [0]:
gold_df = gold_df.withColumn(
    "ValueSegment",
    when(gold_df.CashbackAmount >= 300, "High")
    .when(gold_df.CashbackAmount >= 150, "Medium")
    .otherwise("Low")
)


In [0]:
from pyspark.ml.feature import StringIndexer

categorical_cols = [
    "PreferedOrderCat",
    "MaritalStatus",
    "TenureBucket",
    "ValueSegment"
]

indexers = [
    StringIndexer(
        inputCol=col,
        outputCol=f"{col}_idx",
        handleInvalid="keep"
    ) for col in categorical_cols
]


In [0]:
from pyspark.ml import Pipeline

pipeline = Pipeline(stages=indexers)
gold_df = pipeline.fit(gold_df).transform(gold_df)


In [0]:
feature_cols = [
    "Tenure",
    "WarehouseToHome",
    "NumberOfDeviceRegistered",
    "SatisfactionScore",
    "NumberOfAddress",
    "Complain",
    "DaySinceLastOrder",
    "CashbackAmount",
    "HighRecencyRisk",
    "PreferedOrderCat_idx",
    "MaritalStatus_idx",
    "TenureBucket_idx",
    "ValueSegment_idx"
]


In [0]:
label_col = "Churn"


In [0]:
gold_df.write \
  .format("delta") \
  .mode("overwrite") \
  .saveAsTable("ecommerce_ai.gold.customer_features")


In [0]:
%sql
SELECT COUNT(*) FROM ecommerce_ai.gold.customer_features;


COUNT(*)
3941


In [0]:
gold_df = spark.table("ecommerce_ai.gold.customer_features")
display(gold_df)


Tenure,WarehouseToHome,NumberOfDeviceRegistered,PreferedOrderCat,SatisfactionScore,MaritalStatus,NumberOfAddress,Complain,DaySinceLastOrder,CashbackAmount,Churn,TenureBucket,HighRecencyRisk,ValueSegment,PreferedOrderCat_idx,MaritalStatus_idx,TenureBucket_idx,ValueSegment_idx
15.0,29.0,4,Laptop & Accessory,3,Single,2,0,7.0,143.32,0,Mid,0,Low,0.0,1.0,0.0,1.0
7.0,25.0,4,Mobile,1,Married,2,0,7.0,129.29,0,Mid,0,Low,3.0,0.0,0.0,1.0
27.0,13.0,3,Laptop & Accessory,1,Married,5,0,7.0,168.54,0,Long,0,Medium,0.0,0.0,2.0,0.0
20.0,25.0,4,Fashion,3,Divorced,7,0,3.0,230.27,0,Mid,0,Medium,2.0,2.0,0.0,0.0
30.0,15.0,4,Others,4,Single,8,0,8.0,322.17,0,Long,0,High,5.0,1.0,2.0,2.0
7.0,16.0,4,Mobile Phone,2,Divorced,2,0,11.0,152.81,0,Mid,0,Medium,1.0,2.0,0.0,0.0
1.0,15.0,6,Mobile Phone,5,Divorced,3,0,2.0,149.51,0,New,0,Low,1.0,2.0,1.0,1.0
1.0,11.0,4,Mobile Phone,5,Single,3,0,1.0,154.73,1,New,0,Medium,1.0,1.0,1.0,0.0
11.0,12.0,4,Mobile Phone,3,Married,2,0,4.0,137.02,0,Mid,0,Low,1.0,0.0,0.0,1.0
17.0,7.0,3,Laptop & Accessory,1,Married,5,1,2.0,157.43,0,Mid,0,Medium,0.0,0.0,0.0,0.0


##Model Training & MLflow

In [0]:

from pyspark.ml.feature import VectorAssembler

feature_cols = [
    "Tenure",
    "WarehouseToHome",
    "NumberOfDeviceRegistered",
    "SatisfactionScore",
    "NumberOfAddress",
    "Complain",
    "DaySinceLastOrder",
    "CashbackAmount",
    "HighRecencyRisk",
    "PreferedOrderCat_idx",
    "MaritalStatus_idx",
    "TenureBucket_idx",
    "ValueSegment_idx"
]

assembler = VectorAssembler(
    inputCols=feature_cols,
    outputCol="features"
)

ml_df = assembler.transform(gold_df).select("features", "Churn")


In [0]:
train_df, test_df = ml_df.randomSplit([0.8, 0.2], seed=42)


In [0]:
from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression(
    featuresCol="features",
    labelCol="Churn"
)


In [0]:
%sql
CREATE SCHEMA IF NOT EXISTS ecommerce_ai.ml;

CREATE VOLUME IF NOT EXISTS ecommerce_ai.ml.mlflow_artifacts;


In [0]:
import mlflow
import mlflow.spark

with mlflow.start_run(run_name="logistic_regression_churn"):

    model = lr.fit(train_df)
    predictions = model.transform(test_df)

    mlflow.spark.log_model(
        spark_model=model,
        artifact_path="model",
        dfs_tmpdir="/Volumes/ecommerce_ai/ml/mlflow_artifacts"
    )




In [0]:
from pyspark.sql.functions import col

scored_df = model.transform(gold_df)


{"ts": "2026-01-28 06:38:50.128", "level": "ERROR", "logger": "pyspark.sql.connect.logging", "msg": "GRPC Error received", "context": {}, "exception": {"class": "_InactiveRpcError", "msg": "<_InactiveRpcError of RPC that terminated with:\n\tstatus = StatusCode.INTERNAL\n\tdetails = \"[FIELD_NOT_FOUND] No such struct field `features` in `Tenure`, `WarehouseToHome`, `NumberOfDeviceRegistered`, `PreferedOrderCat`, `SatisfactionScore`, `MaritalStatus`, `NumberOfAddress`, `Complain`, `DaySinceLastOrder`, `CashbackAmount`, `Churn`, `TenureBucket`, `HighRecencyRisk`, `ValueSegment`, `PreferedOrderCat_idx`, `MaritalStatus_idx`, `TenureBucket_idx`, `ValueSegment_idx`. SQLSTATE: 42704\"\n\tdebug_error_string = \"UNKNOWN:Error received from peer  {created_time:\"2026-01-28T06:38:50.127905486+00:00\", grpc_status:13, grpc_message:\"[FIELD_NOT_FOUND] No such struct field `features` in `Tenure`, `WarehouseToHome`, `NumberOfDeviceRegistered`, `PreferedOrderCat`, `SatisfactionScore`, `MaritalStatus`, 

In [0]:
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType

get_churn_prob = udf(lambda v: float(v[1]), DoubleType())

scored_df = scored_df.withColumn(
    "churn_probability",
    get_churn_prob(col("probability"))
)


{"ts": "2026-01-28 06:39:04.190", "level": "ERROR", "logger": "pyspark.sql.connect.logging", "msg": "GRPC Error received", "context": {}, "exception": {"class": "_InactiveRpcError", "msg": "<_InactiveRpcError of RPC that terminated with:\n\tstatus = StatusCode.INTERNAL\n\tdetails = \"[FIELD_NOT_FOUND] No such struct field `features` in `Tenure`, `WarehouseToHome`, `NumberOfDeviceRegistered`, `PreferedOrderCat`, `SatisfactionScore`, `MaritalStatus`, `NumberOfAddress`, `Complain`, `DaySinceLastOrder`, `CashbackAmount`, `Churn`, `TenureBucket`, `HighRecencyRisk`, `ValueSegment`, `PreferedOrderCat_idx`, `MaritalStatus_idx`, `TenureBucket_idx`, `ValueSegment_idx`. SQLSTATE: 42704\"\n\tdebug_error_string = \"UNKNOWN:Error received from peer  {created_time:\"2026-01-28T06:39:04.189971568+00:00\", grpc_status:13, grpc_message:\"[FIELD_NOT_FOUND] No such struct field `features` in `Tenure`, `WarehouseToHome`, `NumberOfDeviceRegistered`, `PreferedOrderCat`, `SatisfactionScore`, `MaritalStatus`, 

In [0]:
from pyspark.sql.functions import when

scored_df = scored_df.withColumn(
    "churn_risk_level",
    when(col("churn_probability") >= 0.75, "High")
    .when(col("churn_probability") >= 0.50, "Medium")
    .otherwise("Low")
)


{"ts": "2026-01-28 06:39:08.492", "level": "ERROR", "logger": "pyspark.sql.connect.logging", "msg": "GRPC Error received", "context": {}, "exception": {"class": "_InactiveRpcError", "msg": "<_InactiveRpcError of RPC that terminated with:\n\tstatus = StatusCode.INTERNAL\n\tdetails = \"[FIELD_NOT_FOUND] No such struct field `features` in `Tenure`, `WarehouseToHome`, `NumberOfDeviceRegistered`, `PreferedOrderCat`, `SatisfactionScore`, `MaritalStatus`, `NumberOfAddress`, `Complain`, `DaySinceLastOrder`, `CashbackAmount`, `Churn`, `TenureBucket`, `HighRecencyRisk`, `ValueSegment`, `PreferedOrderCat_idx`, `MaritalStatus_idx`, `TenureBucket_idx`, `ValueSegment_idx`. SQLSTATE: 42704\"\n\tdebug_error_string = \"UNKNOWN:Error received from peer  {created_time:\"2026-01-28T06:39:08.491790971+00:00\", grpc_status:13, grpc_message:\"[FIELD_NOT_FOUND] No such struct field `features` in `Tenure`, `WarehouseToHome`, `NumberOfDeviceRegistered`, `PreferedOrderCat`, `SatisfactionScore`, `MaritalStatus`, 

In [0]:
scored_df = scored_df.withColumn(
    "retention_action",
    when(
        (col("churn_risk_level") == "High") & (col("ValueSegment") == "High"),
        "Offer High-Value Discount"
    ).when(
        (col("churn_risk_level") == "High") & (col("ValueSegment") == "Medium"),
        "Offer Loyalty Points"
    ).when(
        col("churn_risk_level") == "Medium",
        "Send Personalized Reminder"
    ).otherwise(
        "No Action Needed"
    )
)



{"ts": "2026-01-28 06:39:13.721", "level": "ERROR", "logger": "pyspark.sql.connect.logging", "msg": "GRPC Error received", "context": {}, "exception": {"class": "_InactiveRpcError", "msg": "<_InactiveRpcError of RPC that terminated with:\n\tstatus = StatusCode.INTERNAL\n\tdetails = \"[FIELD_NOT_FOUND] No such struct field `features` in `Tenure`, `WarehouseToHome`, `NumberOfDeviceRegistered`, `PreferedOrderCat`, `SatisfactionScore`, `MaritalStatus`, `NumberOfAddress`, `Complain`, `DaySinceLastOrder`, `CashbackAmount`, `Churn`, `TenureBucket`, `HighRecencyRisk`, `ValueSegment`, `PreferedOrderCat_idx`, `MaritalStatus_idx`, `TenureBucket_idx`, `ValueSegment_idx`. SQLSTATE: 42704\"\n\tdebug_error_string = \"UNKNOWN:Error received from peer  {grpc_message:\"[FIELD_NOT_FOUND] No such struct field `features` in `Tenure`, `WarehouseToHome`, `NumberOfDeviceRegistered`, `PreferedOrderCat`, `SatisfactionScore`, `MaritalStatus`, `NumberOfAddress`, `Complain`, `DaySinceLastOrder`, `CashbackAmount`, 

In [0]:
scored_df = scored_df.withColumn(
    "priority_score",
    col("churn_probability") *
    when(col("ValueSegment") == "High", 3)
    .when(col("ValueSegment") == "Medium", 2)
    .otherwise(1)
)


{"ts": "2026-01-26 10:51:21.101", "level": "ERROR", "logger": "pyspark.sql.connect.logging", "msg": "GRPC Error received", "context": {}, "exception": {"class": "_InactiveRpcError", "msg": "<_InactiveRpcError of RPC that terminated with:\n\tstatus = StatusCode.INTERNAL\n\tdetails = \"[FIELD_NOT_FOUND] No such struct field `features` in `Tenure`, `WarehouseToHome`, `NumberOfDeviceRegistered`, `PreferedOrderCat`, `SatisfactionScore`, `MaritalStatus`, `NumberOfAddress`, `Complain`, `DaySinceLastOrder`, `CashbackAmount`, `Churn`, `TenureBucket`, `HighRecencyRisk`, `ValueSegment`, `PreferedOrderCat_idx`, `MaritalStatus_idx`, `TenureBucket_idx`, `ValueSegment_idx`. SQLSTATE: 42704\"\n\tdebug_error_string = \"UNKNOWN:Error received from peer  {created_time:\"2026-01-26T10:51:21.09849419+00:00\", grpc_status:13, grpc_message:\"[FIELD_NOT_FOUND] No such struct field `features` in `Tenure`, `WarehouseToHome`, `NumberOfDeviceRegistered`, `PreferedOrderCat`, `SatisfactionScore`, `MaritalStatus`, `

In [0]:
final_df = scored_df.select(
    "Tenure",
    "PreferedOrderCat",
    "MaritalStatus",
    "ValueSegment",
    "churn_probability",
    "churn_risk_level",
    "retention_action",
    "priority_score"
)


{"ts": "2026-01-28 06:39:24.073", "level": "ERROR", "logger": "pyspark.sql.connect.logging", "msg": "GRPC Error received", "context": {}, "exception": {"class": "_InactiveRpcError", "msg": "<_InactiveRpcError of RPC that terminated with:\n\tstatus = StatusCode.INTERNAL\n\tdetails = \"[FIELD_NOT_FOUND] No such struct field `features` in `Tenure`, `WarehouseToHome`, `NumberOfDeviceRegistered`, `PreferedOrderCat`, `SatisfactionScore`, `MaritalStatus`, `NumberOfAddress`, `Complain`, `DaySinceLastOrder`, `CashbackAmount`, `Churn`, `TenureBucket`, `HighRecencyRisk`, `ValueSegment`, `PreferedOrderCat_idx`, `MaritalStatus_idx`, `TenureBucket_idx`, `ValueSegment_idx`. SQLSTATE: 42704\"\n\tdebug_error_string = \"UNKNOWN:Error received from peer  {grpc_message:\"[FIELD_NOT_FOUND] No such struct field `features` in `Tenure`, `WarehouseToHome`, `NumberOfDeviceRegistered`, `PreferedOrderCat`, `SatisfactionScore`, `MaritalStatus`, `NumberOfAddress`, `Complain`, `DaySinceLastOrder`, `CashbackAmount`, 

##AI Innovation – Scoring & Recommendations

In [0]:
from pyspark.ml.feature import VectorAssembler

assembler = VectorAssembler(
    inputCols=feature_cols,
    outputCol="features"
)

scoring_df = assembler.transform(gold_df)


In [0]:
scored_df = model.transform(scoring_df)


In [0]:
from pyspark.sql.functions import udf, col
from pyspark.sql.types import DoubleType

get_churn_prob = udf(lambda v: float(v[1]), DoubleType())

scored_df = scored_df.withColumn(
    "churn_probability",
    get_churn_prob(col("probability"))
)


In [0]:
from pyspark.sql.functions import when

scored_df = scored_df.withColumn(
    "churn_risk_level",
    when(col("churn_probability") >= 0.75, "High")
    .when(col("churn_probability") >= 0.50, "Medium")
    .otherwise("Low")
)

scored_df = scored_df.withColumn(
    "retention_action",
    when(
        (col("churn_risk_level") == "High") & (col("ValueSegment") == "High"),
        "Offer High-Value Discount"
    ).when(
        (col("churn_risk_level") == "High") & (col("ValueSegment") == "Medium"),
        "Offer Loyalty Points"
    ).when(
        col("churn_risk_level") == "Medium",
        "Send Personalized Reminder"
    ).otherwise(
        "No Action Needed"
    )
)

scored_df = scored_df.withColumn(
    "priority_score",
    col("churn_probability") *
    when(col("ValueSegment") == "High", 3)
    .when(col("ValueSegment") == "Medium", 2)
    .otherwise(1)
)


In [0]:
final_df = scored_df.select(
    "Tenure",
    "PreferedOrderCat",
    "MaritalStatus",
    "ValueSegment",
    "churn_probability",
    "churn_risk_level",
    "retention_action",
    "priority_score"
)


In [0]:
from pyspark.sql.functions import col

final_df.filter(col("retention_action") != "No Action Needed") \
  .write \
  .format("delta") \
  .mode("overwrite") \
  .saveAsTable("ecommerce_ai.gold.retention_recommendations")



In [0]:
%sql
SHOW TABLES IN ecommerce_ai.gold;


database,tableName,isTemporary
gold,customer_churn_scores,False
gold,customer_features,False
gold,retention_recommendations,False


##Analytics & Business Insights

In [0]:
%sql
SELECT retention_action, COUNT(*) 
FROM ecommerce_ai.gold.retention_recommendations
GROUP BY retention_action;


retention_action,COUNT(*)
Send Personalized Reminder,286
Offer Loyalty Points,30


In [0]:
%sql
SELECT
  ValueSegment,
  churn_probability,
  priority_score,
  retention_action
FROM ecommerce_ai.gold.retention_recommendations
ORDER BY priority_score DESC
LIMIT 20;


ValueSegment,churn_probability,priority_score,retention_action
High,0.7369559238249035,2.2108677714747103,Send Personalized Reminder
High,0.7036166620136542,2.110849986040962,Send Personalized Reminder
Medium,0.9182604458439196,1.8365208916878395,Offer Loyalty Points
Medium,0.9182604458439196,1.8365208916878395,Offer Loyalty Points
Medium,0.8960863789846226,1.7921727579692452,Offer Loyalty Points
Medium,0.8960863789846226,1.7921727579692452,Offer Loyalty Points
Medium,0.8887696855417138,1.7775393710834275,Offer Loyalty Points
Medium,0.8598167709521313,1.7196335419042623,Offer Loyalty Points
Medium,0.8419543358912638,1.6839086717825276,Offer Loyalty Points
Medium,0.8350489210374245,1.670097842074849,Offer Loyalty Points


In [0]:
%sql
SELECT 
  churn_risk_level,
  COUNT(*) AS customer_count
FROM ecommerce_ai.gold.customer_churn_scores
GROUP BY churn_risk_level
ORDER BY customer_count DESC;


churn_risk_level,customer_count
Low,3513
Medium,286
High,142
