In [0]:
from pyspark.sql.functions import col, lit, when, round as spark_round
import pyspark.sql.functions as F

# -------------------------
# CONFIG
# -------------------------

# Width of the income grid, in the same units as annual_income. User incomes
# are rounded to the nearest multiple of this before joining to the premium
# surface's income column.
INCOME_STEP = 250

# Multiplied by CLIFF_FPL to obtain the cliff income threshold.
# NOTE(review): presumably the single-person federal poverty level in
# dollars — confirm the reference year.
FPL_SINGLE = 15_000

# -------------------------
# LOAD TABLES (WITH ALIASES)
# -------------------------

# The aliases "u" and "s" let later steps qualify columns as u.<col> / s.<col>
# once the two frames are joined.
demo_users = spark.table("workspace.bronze.demo_users").alias("u")
silver = spark.table("workspace.silver.subsidy_adjusted_premiums").alias("s")
subsidy_params = spark.table("workspace.bronze.subsidy_params")

# Read the scalar cliff threshold (a multiplier applied to FPL_SINGLE).
# first() returns None for an empty table, and float(None) fails with an
# opaque error, so fail early with a message that names the source table.
# NOTE(review): on a multi-row table first() returns an arbitrary row; this
# assumes subsidy_params holds a single parameter row — confirm.
params = subsidy_params.first()
if params is None:
    raise ValueError(
        "workspace.bronze.subsidy_params is empty; cannot read cliff_fpl"
    )
if params.cliff_fpl is None:
    raise ValueError("workspace.bronze.subsidy_params.cliff_fpl is null")
CLIFF_FPL = float(params.cliff_fpl)

# -------------------------
# SNAP USER INCOME TO GRID
# -------------------------

# Round each user's income to the nearest INCOME_STEP multiple so it can be
# matched exactly against s.income in the join below.
_snapped_income = spark_round(col("u.annual_income") / INCOME_STEP) * INCOME_STEP
demo_users = demo_users.withColumn("income_snapped", _snapped_income.cast("int"))

# -------------------------
# JOIN USERS WITH PREMIUM SURFACE
# -------------------------

# Match each user to every premium-surface row with the same snapped income
# and county. This is an inner join, so users with no matching surface row
# are dropped rather than carried through with nulls.
_surface_match = (
    (col("s.income") == col("income_snapped"))
    & (col("s.county") == col("u.demo_county"))
)
user_plan = demo_users.join(silver, on=_surface_match, how="inner")

# -------------------------
# DISTANCE TO CLIFF
# -------------------------

# Income headroom below the cliff income (CLIFF_FPL * FPL_SINGLE), floored at
# zero: users at or above the cliff income get 0. This uses the raw
# annual_income, not the snapped grid value.
_cliff_income = lit(CLIFF_FPL * FPL_SINGLE)
_headroom = _cliff_income - col("u.annual_income")
user_plan = user_plan.withColumn("distance_to_cliff", F.greatest(lit(0), _headroom))

# -------------------------
# ELASTICITY RATIO
# -------------------------

# fragility_slope scaled by income / net_premium. Rows whose net_premium is not
# strictly positive (or is null) get a null ratio, which also guards against
# dividing by zero.
_income_per_premium = col("u.annual_income") / col("s.net_premium")
_has_positive_premium = col("s.net_premium") > 0
user_plan = user_plan.withColumn(
    "elasticity_ratio",
    when(_has_positive_premium, col("s.fragility_slope") * _income_per_premium),
)

# -------------------------
# STABILITY CLASSIFICATION
# -------------------------

# Classification thresholds, named so they can be tuned in one place instead of
# being buried in the expression.
# NOTE(review): units presumably match distance_to_cliff (income) and
# fragility_slope respectively — confirm.
CLIFF_PRONE_DISTANCE = 1500
MODERATE_SLOPE_THRESHOLD = 0.05

# The first matching rule wins:
#   distance_to_cliff < CLIFF_PRONE_DISTANCE    -> "Cliff-Prone"
#     (distance_to_cliff is floored at 0, so users already at or above the
#      cliff income also land here)
#   fragility_slope > MODERATE_SLOPE_THRESHOLD  -> "Moderately Sensitive"
#   otherwise, including a null fragility_slope -> "Stable"
user_plan = user_plan.withColumn(
    "stability_classification",
    when(col("distance_to_cliff") < CLIFF_PRONE_DISTANCE, "Cliff-Prone")
    .when(col("s.fragility_slope") > MODERATE_SLOPE_THRESHOLD, "Moderately Sensitive")
    .otherwise("Stable"),
)

# -------------------------
# FINAL SELECT + WRITE
# -------------------------

# Project the joined frame onto the gold output columns, stripping the u./s.
# qualifiers.
gold_metrics = user_plan.select(
    col("u.user_id").alias("user_id"),
    col("s.plan_id").alias("plan_id"),
    col("u.annual_income").alias("annual_income"),
    col("s.net_premium").alias("net_premium"),
    col("s.fragility_slope").alias("fragility_slope"),
    "distance_to_cliff",
    "elasticity_ratio",
    "stability_classification",
)

# Replace the table with one Delta overwrite instead of DROP TABLE + create.
# Dropping first leaves a window where the table does not exist, loses the
# table entirely if the write then fails, and (for a managed table) discards
# its Delta history and any grants on it. overwriteSchema covers what the DROP
# was presumably for: letting the output schema change between runs.
(
    gold_metrics.write
    .format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable("workspace.gold.cliff_proximity_metrics")
)