In [0]:
import traceback
import uuid
from datetime import datetime, timezone

from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import col, count, lit, round, split, trim, upper, when
from pyspark.sql.types import DoubleType

# ---------------------------
# Spark session shared by every step of this job
# ---------------------------
spark = (
    SparkSession.builder
    .appName("HealthcareSilverLayer")
    .getOrCreate()
)

# ---------------------------
# Unique identifier stamped on every log record written by this run
# ---------------------------
job_id = f"{uuid.uuid4()}"

# ---------------------------
# Logging function
# ---------------------------
def log_step(step_name, status="INFO", message=""):
    """Append one log record for this run to the ``logs.silver_logs.log`` Delta table.

    Args:
        step_name: Short label of the pipeline step (e.g. ``"LOAD_BRONZE"``).
        status: Outcome/severity label; defaults to ``"INFO"``.
        message: Free-text detail for the record; defaults to ``""``.

    Each call builds a one-row DataFrame (columns job_id, timestamp, step,
    status, message) and runs a Spark append write, so it is not free.
    """
    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC time
    # and drop tzinfo so the ISO string keeps the original naive format
    # (no "+00:00" suffix) already stored in the log table.
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    log_df = spark.createDataFrame([
        Row(
            job_id=job_id,
            timestamp=timestamp,
            step=step_name,
            status=status,
            message=message
        )
    ])
    (
        log_df.write
        .format("delta")
        .mode("append")
        .saveAsTable("logs.silver_logs.log")
    )

# ---------------------------
# Helper function: count nulls per column
# ---------------------------
def null_count(df, columns):
    """Return ``{column_name: null_row_count}`` for each name in *columns*.

    All counts come from a single aggregation pass over *df* rather than one
    filter + count Spark job per column.

    Args:
        df: Spark DataFrame to inspect.
        columns: Iterable of column names present in *df*.

    Returns:
        dict mapping each column name to the number of rows where it is null.
    """
    columns = list(columns)
    if not columns:
        # Nothing to aggregate; match the empty dict the per-column version gave.
        return {}
    # count() ignores nulls, and when() without otherwise() yields null for rows
    # that fail the condition, so each expression counts exactly the null rows.
    # A global aggregation always returns one row, even for an empty DataFrame.
    row = df.agg(
        *[count(when(col(name).isNull(), True)) for name in columns]
    ).first()
    # Read by position so unusual column names cannot clash with generated aliases.
    return {name: row[idx] for idx, name in enumerate(columns)}

try:
    log_step("START", "INFO", "Healthcare silver layer pipeline started")

    # ---------------------------
    # Load Bronze Layer
    # ---------------------------
    bronze_df = spark.read.table("processed_outputs.bronze.patient")
    # Every .count() is a full Spark job; count the source once and reuse it.
    bronze_count = bronze_df.count()
    log_step("LOAD_BRONZE", "INFO", f"Loaded {bronze_count} rows from Bronze layer")

    # ---------------------------
    # Step 1: Deduplicate and Drop Bad Rows
    # ---------------------------
    before_count = bronze_count
    silver_df = bronze_df.dropDuplicates()
    after_dedup_count = silver_df.count()
    required_cols = ["State", "Sex", "AgeCategory", "GeneralHealth"]
    # Nothing changes silver_df between dedup and dropna, so the post-dedup
    # count is also the pre-dropna count; recounting would just rescan.
    before_dropna_count = after_dedup_count
    silver_df = silver_df.dropna(subset=required_cols)
    after_dropna_count = silver_df.count()
    log_step("DEDUPLICATE", "INFO",
             f"Rows before dedup: {before_count}, after dedup: {after_dedup_count}, "
             f"before dropna: {before_dropna_count}, after dropna: {after_dropna_count}")

    # ---------------------------
    # Step 2: Clean String/Text Columns
    # ---------------------------
    text_cols = ["State", "Sex", "GeneralHealth", "SmokerStatus", "ECigaretteUsage",
                 "RaceEthnicityCategory", "AgeCategory", "AlcoholDrinkers"]
    for col_name in text_cols:
        silver_df = silver_df.withColumn(col_name, upper(trim(col(col_name))))
    # Keep only the text before the first comma of the race/ethnicity label.
    silver_df = silver_df.withColumn("RaceEthnicityCategory", split(col("RaceEthnicityCategory"), ",")[0])
    log_step("TEXT_CLEAN", "INFO", "Standardized string columns")

    # ---------------------------
    # Step 3: Standardize Binary (Yes/No) Columns
    # ---------------------------
    binary_cols = ["HadHeartAttack", "HadAngina", "HadStroke", "HadAsthma", "HadSkinCancer",
                   "HadCOPD", "HadDepressiveDisorder", "HadKidneyDisease", "HadArthritis",
                   "HadDiabetes", "DeafOrHardOfHearing", "BlindOrVisionDifficulty",
                   "DifficultyConcentrating", "DifficultyWalking", "DifficultyDressingBathing",
                   "DifficultyErrands", "AlcoholDrinkers", "HIVTesting", "FluVaxLast12",
                   "PneumoVaxEver", "TetanusLast10Tdap", "HighRiskLastYear", "CovidPos"]
    pre_null_counts = null_count(silver_df, binary_cols)
    for col_name in binary_cols:
        # Trim as in Step 2 so padded values like " Yes" still map;
        # anything other than YES/NO becomes null.
        normalized = upper(trim(col(col_name)))
        silver_df = silver_df.withColumn(
            col_name,
            when(normalized == "YES", 1)
            .when(normalized == "NO", 0)
            .otherwise(None)
        )
    post_null_counts = null_count(silver_df, binary_cols)
    log_step("BINARY_STANDARDIZE", "INFO",
             f"Null counts before: {pre_null_counts}, after: {post_null_counts}")

    # ---------------------------
    # Step 4–7: Numeric Conversion, Null Handling
    # ---------------------------
    numeric_cols = ["HeightInMeters", "WeightInKilograms", "BMI"]
    for col_name in numeric_cols:
        silver_df = silver_df.withColumn(col_name, col(col_name).cast(DoubleType()))

    before_filter_count = silver_df.count()
    # Null comparisons evaluate to null, so rows with missing measurements are dropped too.
    silver_df = silver_df.filter(
        (col("HeightInMeters") > 0) & (col("WeightInKilograms") > 0) & (col("BMI") > 0)
    )
    after_filter_count = silver_df.count()

    # `round` here is pyspark.sql.functions.round (imported at file level).
    silver_df = silver_df.withColumn("BMI", round(col("BMI"), 2))
    silver_df = silver_df.withColumn("HeightInMeters", round(col("HeightInMeters"), 2))
    silver_df = silver_df.withColumn("WeightInKilograms", round(col("WeightInKilograms"), 1))

    fill_na_dict = {
        "GeneralHealth": "MISSING",
        "SmokerStatus": "UNKNOWN",
        "ECigaretteUsage": "UNKNOWN",
        "RaceEthnicityCategory": "OTHER",
        "LastCheckupTime": "UNKNOWN",
        "RemovedTeeth": "UNKNOWN",
        "CovidPos": 0,
        "HadDiabetes": 0
    }
    silver_df = silver_df.fillna(fill_na_dict)
    log_step("NUMERIC_CLEAN", "INFO",
             f"Rows before numeric filter: {before_filter_count}, after filter: {after_filter_count}. "
             f"Filled NAs: {fill_na_dict}")

    # ---------------------------
    # Step 9: Load Hospital Dimension Table
    # ---------------------------
    hospital_df = spark.read.table("processed_outputs.bronze.hospital")
    hospital_df = hospital_df.drop("Location")
    hospital_df = hospital_df.withColumn("State", upper(trim(col("State"))))
    # Resolve missing/blank States BEFORE the generic fillna: filling first
    # turned null States into "NOT AVAILABLE" and made the isNull() branch dead.
    hospital_df = hospital_df.withColumn(
        "State",
        when(col("State").isNull() | (trim(col("State")) == ""), lit("UNKNOWN")).otherwise(col("State"))
    )
    # fillna with a string value only affects string-typed columns.
    hospital_df = hospital_df.fillna("NOT AVAILABLE")
    log_step("LOAD_HOSPITAL", "INFO", "Hospital dimension table loaded & cleaned")

    # ---------------------------
    # Step 11: Join Patients with Hospital Info
    # ---------------------------
    # NOTE(review): this is an inner join on State, so patients whose State has
    # no hospital row are dropped, and if several hospitals share a State each
    # patient row is repeated once per hospital. Confirm that is intended.
    enriched_df = silver_df.join(hospital_df, silver_df.State == hospital_df.State, "inner") \
                           .drop(hospital_df.State)
    # NOTE(review): assumes Provider_ID is a string column; the "NOT AVAILABLE"
    # sentinel is only filled into string columns by the fillna above.
    enriched_df = enriched_df.withColumn(
        "Hospital_Info_Available",
        when(col("Provider_ID") != "NOT AVAILABLE", lit(1)).otherwise(lit(0))
    )
    log_step("JOIN", "INFO", f"Joined patients with hospitals. Final rows: {enriched_df.count()}")

    # ---------------------------
    # Step 12: Drop Unwanted Columns
    # ---------------------------
    # Dropping a column that is not present (e.g. Location, removed in Step 9) is a no-op.
    drop_list = ["TetanusLast10Tdap", "Location", "Safety_of_care_national_comparison"]
    enriched_df = enriched_df.drop(*drop_list)
    log_step("DROP_COLUMNS", "INFO", f"Dropped columns: {drop_list}")

    # ---------------------------
    # Step 13: Save to Silver Layer
    # ---------------------------
    # "inferSchema" is a reader option and does nothing on a write, so it is not set.
    enriched_df.write.format("delta") \
        .mode("overwrite") \
        .option("overwriteSchema", "true") \
        .partitionBy("State") \
        .saveAsTable("processed_outputs.silver.processed")
    log_step("SAVE_SILVER", "INFO", "Silver Layer data enriched & saved")

    log_step("END", "SUCCESS", "Pipeline completed successfully ✅")

except Exception:
    try:
        log_step("ERROR", "FAIL", traceback.format_exc())
    except Exception:
        # Error logging is best-effort: if the log write itself fails, print it
        # rather than letting it replace the pipeline error re-raised below.
        traceback.print_exc()
    raise