In [0]:
from pyspark.sql.functions import col
from pyspark.sql.types import StructType, StructField, StringType, DoubleType
import numpy as np
import pandas as pd

# -----------------------
# CONFIG
# -----------------------

# Fixed seed for NumPy's global RNG, which every np.random call below uses.
np.random.seed(42)
# Number of simulated draws per (user, plan) pair.
N_SIM = 10000

# -----------------------
# LOAD TABLES
# -----------------------

demo_users = spark.table("workspace.bronze.demo_users")
plans = spark.table("workspace.bronze.plans")

# Keep at most five plans whose county is Fulton.
# NOTE(review): limit() is applied without an ordering, so which five rows
# come back is not pinned by this code -- confirm that is acceptable.
is_fulton = col("county") == "Fulton"
plans_demo = plans.where(is_fulton).limit(5)

# Collect both tables to the driver as pandas DataFrames; the simulation
# below runs in plain NumPy rather than on Spark.
plans_pd = plans_demo.toPandas()
users_pd = demo_users.toPandas()

# -----------------------
# COST DISTRIBUTIONS
# -----------------------

# Per-visit cost model: each pair is (median, log-scale sigma) of a
# lognormal distribution; the median is converted to mu by lognormal_params.
OFFICE_MEDIAN, OFFICE_SIGMA = 150, 0.6
ER_MEDIAN, ER_SIGMA = 1200, 0.9

# (low, high) bounds of the uniform monthly medication cost for the
# low-cost and high-cost bands.
MED_LOW = 10, 60
MED_HIGH = 100, 600

def lognormal_params(median, sigma):
    """Return the ``(mu, sigma)`` pair for a lognormal with the given median.

    The median of a lognormal distribution equals ``exp(mu)``, so ``mu`` is
    the natural log of ``median``. ``sigma`` is returned unchanged.
    """
    return np.log(median), sigma

# Log-scale (mu, sigma) pairs for office-visit and ER-visit costs.
(mu_off, sigma_off), (mu_er, sigma_er) = (
    lognormal_params(OFFICE_MEDIAN, OFFICE_SIGMA),
    lognormal_params(ER_MEDIAN, ER_SIGMA),
)

# One (user_id, plan_id, breach_probability, mean_oop, p90) tuple is appended
# per (user, plan) pair.
results = []

# Share of the allowed cost above the deductible that counts as out-of-pocket.
coinsurance = 0.2


def _sum_visit_costs(counts, mu, sigma):
    """Return the summed lognormal visit cost for each simulation.

    Args:
        counts: 1-D array of non-negative integers; element ``i`` is the
            number of visits in simulation ``i``.
        mu: Mean of the underlying normal distribution (log scale).
        sigma: Standard deviation of the underlying normal distribution.

    Returns:
        Float array with the same length as ``counts``; element ``i`` is the
        sum of ``counts[i]`` lognormal draws, or 0.0 when there are none.
    """
    counts = np.asarray(counts)
    n_sims = counts.shape[0]
    n_visits = int(counts.sum())
    if n_visits == 0:
        return np.zeros(n_sims)
    # Draw every visit cost in one call instead of one np.random call per
    # simulation, then add each cost into the bucket of its simulation.
    visit_costs = np.random.lognormal(mu, sigma, n_visits)
    sim_index = np.repeat(np.arange(n_sims), counts)
    return np.bincount(sim_index, weights=visit_costs, minlength=n_sims)


# -----------------------
# LOOP OVER USERS
# -----------------------

for _, user in users_pd.iterrows():

    user_id = user["user_id"]
    med_count = user["medication_count"]
    # Poisson rates are floored so they are never zero or negative.
    er_lambda = max(user["expected_er_visits"], 0.05)
    # presumably therapy_frequency is a per-month rate (hence * 12) -- confirm.
    therapy_lambda = max(user["therapy_frequency"] * 12, 1)

    # Medication risk adjustment: probability of landing in the high-cost
    # medication band, rising with medication count and capped at 0.6.
    high_med_prob = min(0.2 + 0.05 * med_count, 0.6)

    # -----------------------
    # LOOP OVER PLANS
    # -----------------------

    # NOTE(review): utilization and allowed cost below do not depend on the
    # plan; they are still re-sampled per plan to keep the existing draw order.
    for _, plan in plans_pd.iterrows():

        deductible = float(plan["deductible"])

        # A null oop_max arrives from pandas as NaN/None. NaN is truthy, so a
        # plain truthiness test would let it through and np.minimum would
        # then turn every out-of-pocket value into NaN. Missing (or zero)
        # values fall back to twice the deductible.
        raw_oop_max = plan["oop_max"]
        if pd.isna(raw_oop_max) or not raw_oop_max:
            oop_max = deductible * 2
        else:
            oop_max = float(raw_oop_max)

        # -------- UTILIZATION --------

        office_counts = np.random.poisson(therapy_lambda, N_SIM)
        er_counts = np.random.poisson(er_lambda, N_SIM)

        # Each simulation picks the low or high monthly medication band, then
        # a uniform cost inside that band; the monthly cost is annualized.
        med_random = np.random.rand(N_SIM)
        med_monthly = np.where(
            med_random < (1 - high_med_prob),
            np.random.uniform(MED_LOW[0], MED_LOW[1], N_SIM),
            np.random.uniform(MED_HIGH[0], MED_HIGH[1], N_SIM)
        )
        med_total = med_monthly * 12

        # -------- COST SAMPLING --------

        office_costs = _sum_visit_costs(office_counts, mu_off, sigma_off)
        er_costs = _sum_visit_costs(er_counts, mu_er, sigma_er)

        annual_allowed = office_costs + er_costs + med_total

        # -------- OOP CALCULATION --------

        # Full cost up to the deductible, coinsurance on the excess, and the
        # total capped at the out-of-pocket maximum.
        oop = np.where(
            annual_allowed <= deductible,
            annual_allowed,
            deductible + coinsurance * (annual_allowed - deductible)
        )

        oop = np.minimum(oop, oop_max)

        # -------- METRICS --------

        # Fraction of simulations whose allowed cost exceeds the deductible.
        breach_prob = np.mean(annual_allowed > deductible)
        mean_oop = np.mean(oop)
        p90 = np.percentile(oop, 90)

        results.append((
            str(user_id),
            str(plan["plan_id"]),
            float(breach_prob),
            float(mean_oop),
            float(p90)
        ))

# -----------------------
# WRITE TO GOLD
# -----------------------

# Column name and Spark type for each field of a result tuple, in tuple
# order; every column is nullable.
_gold_columns = [
    ("user_id", StringType()),
    ("plan_id", StringType()),
    ("breach_probability", DoubleType()),
    ("mean_oop", DoubleType()),
    ("p90_exposure", DoubleType()),
]
schema = StructType(
    [StructField(name, dtype, True) for name, dtype in _gold_columns]
)

# The existing table is dropped before the new results are written.
spark.sql("DROP TABLE IF EXISTS workspace.gold.monte_carlo_risk_metrics")

gold_mc = spark.createDataFrame(results, schema=schema)

(
    gold_mc.write
    .format("delta")
    .mode("overwrite")
    .saveAsTable("workspace.gold.monte_carlo_risk_metrics")
)