# In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, round, regexp_replace, trim
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler, StandardScaler
from pyspark.ml import Pipeline


# =====================================================================
# CONFIG
# =====================================================================
# Source table read by ingest_data(); a three-part <catalog>.<schema>.<table>
# name (presumably a Unity Catalog table — confirm in the workspace).
INPUT_DELTA_TABLE: str = "workspace.ml_credit_risk.credit_risk_data"


# =====================================================================
# SPARK INITIALIZATION
# =====================================================================
def initialize_spark():
    """Return the SparkSession used by the preprocessing steps.

    If this module's globals already hold a ``spark`` object that is a
    SparkSession (e.g. one predefined by a notebook runtime), that object is
    returned as-is. Otherwise a session named "CreditRiskPreprocessing" is
    obtained via ``getOrCreate``, which also reuses an existing session.
    """
    # globals().get yields None when the name is absent, and None is never a
    # SparkSession, so both "missing" and "wrong type" fall through.
    existing = globals().get('spark')
    if isinstance(existing, SparkSession):
        return existing

    builder = SparkSession.builder.appName("CreditRiskPreprocessing")
    return builder.getOrCreate()


# =====================================================================
# STEP 1: DATA INGESTION
# =====================================================================
def ingest_data(spark: SparkSession, table_name: str):
    """Load the credit-risk table and keep only the modelling columns.

    Args:
        spark: Session used to read the table.
        table_name: Table to read (requested with format "delta").

    Returns:
        DataFrame with the 16 feature columns below, in that order, followed
        by a ``label`` column copied from the source ``default`` column.

    Note:
        ``count()`` is a Spark action, so printing the row count runs a job
        over the selected data.
    """
    print(f"üì• Loading: {table_name}")

    raw_df = spark.read.format("delta").table(table_name)

    # Model inputs; this order becomes the column order of the result.
    feature_cols = [
        'checking_balance',
        'months_loan_duration',
        'credit_history',
        'purpose',
        'amount',
        'savings_balance',
        'employment_duration',
        'percent_of_income',
        'years_at_residence',
        'age',
        'other_credit',
        'housing',
        'existing_loans_count',
        'job',
        'dependents',
        'phone',
    ]
    target = col("default").alias("label")

    selected_df = raw_df.select(*feature_cols, target)
    loaded_rows = selected_df.count()
    print(f"‚úÖ Loaded Rows: {loaded_rows:,}")
    return selected_df


# =====================================================================
# STEP 2: UNIT CLEANUP (KEEP SAME COLUMN NAMES)
# =====================================================================
def cleanup_units(df):
    """Strip unit suffixes from banded columns and normalise missing values.

    Column names are kept unchanged; only values are rewritten:

    * ``checking_balance`` / ``savings_balance``: a trailing "DM" (and the
      whitespace around it) is removed, e.g. "< 0 DM" -> "< 0".
    * ``employment_duration``: a trailing "year"/"years" is removed,
      e.g. "1 - 4 years" -> "1 - 4".

    Afterwards every string column has empty strings AND nulls replaced by
    "unknown". That way a missing value reaches the explicit "unknown" level
    declared in ``ordinal_encoding``, instead of that function's out-of-range
    fallback level. It is also indexed like the literal "unknown" category by
    ``onehot_encoding``.

    Args:
        df: DataFrame produced by ``ingest_data``.

    Returns:
        DataFrame with the same columns and cleaned string values.
    """
    print("üßπ Cleaning units...")

    df = df.withColumn("checking_balance", trim(regexp_replace(col("checking_balance"), r"\s*DM\s*$", "")))
    df = df.withColumn("savings_balance", trim(regexp_replace(col("savings_balance"), r"\s*DM\s*$", "")))
    df = df.withColumn("employment_duration", trim(regexp_replace(col("employment_duration"), r"\s*years?\s*$", "")))

    # Replace blanks with "unknown". DataFrame.replace only matches exact
    # values and never touches nulls, so nulls are filled as well. A string
    # fill value is applied to string-typed columns only, so numeric columns
    # keep their nulls. A null string label becomes "unknown", which
    # prepare_data maps to 0.0 exactly as it did a null.
    df = df.replace("", "unknown").fillna("unknown")

    print("‚úÖ Unit cleanup done")
    return df


# =====================================================================
# STEP 3: DATA PREPARATION
# =====================================================================
def prepare_data(df):
    """Encode the target as a double and derive ``monthly_income``.

    * ``label``: 1.0 when the value equals "yes" exactly, otherwise 0.0.
      This includes nulls and any other spelling.
    * ``monthly_income``: the per-month installment
      (``amount / months_loan_duration``) scaled by
      ``100 / percent_of_income``, rounded to 2 decimals. It is null whenever
      either divisor is null or not strictly positive.

    NOTE(review): the formula assumes ``percent_of_income`` is a true
    percentage (25 meaning 25%). If the column is instead a small coded rate
    bucket, the result is not a real income — confirm against the dataset's
    data dictionary.

    Args:
        df: DataFrame with string ``label`` and numeric ``amount``,
            ``months_loan_duration`` and ``percent_of_income`` columns.

    Returns:
        ``df`` with ``label`` replaced and ``monthly_income`` appended.
    """
    print("üìä Preparing data...")

    is_default = col("label") == "yes"
    df = df.withColumn("label", when(is_default, 1.0).otherwise(0.0))

    pct_of_income = col("percent_of_income")
    duration = col("months_loan_duration")
    # Guard both divisors; a failed (or null) guard falls through to null.
    has_valid_divisors = (pct_of_income > 0) & (duration > 0)
    income_estimate = (col("amount") / duration) * (100 / pct_of_income)

    # `round` is pyspark.sql.functions.round (imported at file top), which
    # shadows the builtin here.
    df = df.withColumn(
        "monthly_income",
        round(when(has_valid_divisors, income_estimate).otherwise(None), 2),
    )

    print("‚úÖ Preparation complete")
    return df


# =====================================================================
# STEP 4: ORDINAL ENCODING
# =====================================================================
def ordinal_encoding(df):
    """Replace ordered categorical columns with their rank as a double.

    For each configured column, the value at position ``i`` in its level
    list becomes ``float(i)``. Any value not in the list, including null
    (an equality test against null never matches), becomes
    ``float(len(levels))``. Columns are overwritten in place.

    NOTE(review): "unknown" is ranked above the highest real band, which
    treats "missing" as the top of the scale — a modelling choice worth
    confirming.

    Args:
        df: DataFrame with the string columns named in the config below.

    Returns:
        DataFrame with those columns converted to double ranks.
    """
    print("üî¢ Applying ordinal encoding...")

    ordinal_config = {
        'checking_balance': ['< 0', '1 - 200', '> 200', 'unknown'],
        'savings_balance': ['< 100', '100 - 500', '500 - 1000', '> 1000', 'unknown'],
        'employment_duration': ['unemployed', '< 1', '1 - 4', '4 - 7', '> 7', 'unknown'],
        'credit_history': ['critical', 'poor', 'good', 'very good', 'perfect']
    }

    for col_name, levels in ordinal_config.items():
        source = col(col_name)
        # Seed the CASE WHEN chain with the first level, then extend it.
        rank_expr = when(source == levels[0], 0.0)
        for rank, level in enumerate(levels[1:], start=1):
            rank_expr = rank_expr.when(source == level, float(rank))
        fallback_rank = float(len(levels))
        df = df.withColumn(col_name, rank_expr.otherwise(fallback_rank))

    print("‚úÖ Ordinal encoding done")
    return df


# =====================================================================
# STEP 5: ONE-HOT ENCODING (FIXED FOR DATABRICKS REPOS)
# =====================================================================
def onehot_encoding(df):
    """One-hot encode the nominal columns without building a ``Pipeline``.

    The stages are fitted and applied directly. The original author notes
    this avoids Databricks Repos security restrictions on ``Pipeline``.

    Steps:
        1. One multi-column ``StringIndexer`` maps each nominal column to a
           ``<name>_index`` column. ``handleInvalid="keep"`` gives unseen
           labels (and nulls) an extra index instead of failing.
        2. ``OneHotEncoder`` (``dropLast=True``) turns each index column into
           a ``<name>_vec`` vector column.
        3. The original nominal columns and the index columns are dropped.

    Args:
        df: DataFrame containing the string columns listed in
            ``nominal_cols``.

    Returns:
        DataFrame in which each nominal column is replaced by a
        ``<name>_vec`` vector column.
    """
    print("üî• Applying Repo-safe OneHot Encoding...")

    nominal_cols = ['purpose', 'other_credit', 'housing', 'job', 'phone']
    index_cols = [f"{c}_index" for c in nominal_cols]
    vec_cols = [f"{c}_vec" for c in nominal_cols]

    # Multi-column StringIndexer (Spark >= 3.0, the same requirement as the
    # inputCols/outputCols OneHotEncoder below). It learns every column's
    # label ordering in a single fit() call, instead of running one fit job
    # per column. Per-column semantics are unchanged.
    indexer = StringIndexer(
        inputCols=nominal_cols,
        outputCols=index_cols,
        handleInvalid="keep"
    )
    df = indexer.fit(df).transform(df)

    # With dropLast=True, the last category of each column maps to the
    # all-zeros vector. NOTE(review): with the "keep" indexer above, that
    # last category should be the extra unseen/null bucket — confirm against
    # the Spark version in use.
    encoder = OneHotEncoder(
        inputCols=index_cols,
        outputCols=vec_cols,
        dropLast=True
    )
    df = encoder.fit(df).transform(df)

    # Remove original nominal + index columns, keep only vectors
    df = df.drop(*nominal_cols, *index_cols)

    print("‚úÖ One-hot encoding complete")
    return df


# =====================================================================
# STEP 6: STANDARD SCALING
# =====================================================================
def apply_standard_scaling(df, *, handle_invalid="error"):
    """Assemble every non-label column into one vector and standardise it.

    Features are taken from ``df.columns`` (in that order, excluding
    ``label``). They are assembled into ``unscaled_features``, then scaled to
    zero mean and unit standard deviation into ``features``.

    Args:
        df: DataFrame whose columns, apart from ``label``, are numeric or
            vector-typed (VectorAssembler accepts nothing else).
        handle_invalid: Forwarded to ``VectorAssembler(handleInvalid=...)``.
            It decides what happens to rows with null/NaN feature values:

            * "error" (default, unchanged behaviour) raises.
            * "skip" drops those rows.
            * "keep" emits NaN entries. NaN would then likely reach
              StandardScaler and poison that column's statistics, so prefer
              "skip" or impute beforehand.

            ``prepare_data`` sets ``monthly_income`` to null whenever a
            divisor is not positive, so such rows fail under "error".

    Returns:
        Tuple ``(scaled_df, scaler_model)``. ``scaled_df`` has only the
        ``features`` and ``label`` columns. ``scaler_model`` is the fitted
        StandardScalerModel, which can be reused on other data.
    """
    print("üìè Scaling features...")

    feature_cols = [c for c in df.columns if c != "label"]

    # Step 1: Assemble features
    assembler = VectorAssembler(
        inputCols=feature_cols,
        outputCol="unscaled_features",
        handleInvalid=handle_invalid
    )
    df = assembler.transform(df)

    # Step 2: Scale features
    scaler = StandardScaler(
        inputCol="unscaled_features",
        outputCol="features",
        withMean=True,
        withStd=True
    )
    scaler_model = scaler.fit(df)
    df = scaler_model.transform(df).select("features", "label")

    print("‚úÖ Scaling complete")
    return df, scaler_model


# =====================================================================
# MAIN EXECUTION
# =====================================================================
if __name__ == "__main__":
    print("\nüöÄ CREDIT RISK PIPELINE STARTING\n")

    # Reuses a module-level `spark` session when one already exists.
    spark = initialize_spark()

    # Ingest, then run each DataFrame -> DataFrame stage in order. Results
    # are kept as module-level names (not inside a main()), so they remain
    # available after this script/cell finishes.
    df = ingest_data(spark, INPUT_DELTA_TABLE)
    for transform_stage in (cleanup_units, prepare_data, ordinal_encoding, onehot_encoding):
        df = transform_stage(df)
    processed_df, scaler_model = apply_standard_scaling(df)

    # Summary: count() and show() are actions, and each re-executes the plan.
    print("\nüéâ Pipeline complete!")
    print(f"Rows: {processed_df.count():,}")
    processed_df.printSchema()
    processed_df.show(5)