In [0]:
# Databricks notebook: MAS610_Gold_FactExposure
# Purpose: join & derive MAS610 exposure metrics (Exercise 2)

from pyspark.sql import SparkSession, functions as F
from datetime import datetime

# Obtain the Spark session for this notebook. getOrCreate() hands back the
# already-active session when there is one; the builder (and its app name) is
# only used if a new session has to be created.
_session_builder = SparkSession.builder.appName("MAS610_Gold_FactExposure")
spark = _session_builder.getOrCreate()

# ---------------------------------------------------------------------
# Parameters
# ---------------------------------------------------------------------
# Register one text widget per root path (Silver input, Gold output) with the
# default shown below, then read back the widgets' current values.
_path_widget_defaults = {
    "silver_dir": "/mnt/silver/mas610_ex2",
    "gold_dir": "/mnt/gold/mas610",
}
for _widget_name, _widget_default in _path_widget_defaults.items():
    dbutils.widgets.text(_widget_name, _widget_default)

silver_dir = dbutils.widgets.get("silver_dir")
gold_dir = dbutils.widgets.get("gold_dir")

# ---------------------------------------------------------------------
# Read Silver tables
# ---------------------------------------------------------------------
def _load_silver_table(table_name):
    """Return the Delta table stored at ``<silver_dir>/<table_name>_silver.delta``."""
    return spark.read.format("delta").load(f"{silver_dir}/{table_name}_silver.delta")


accounts = _load_silver_table("accounts")
loans = _load_silver_table("loans")
collat = _load_silver_table("collateral")

# ---------------------------------------------------------------------
# Join & Enrich → MAS610 Fact Exposure
# ---------------------------------------------------------------------
# Build the exposure fact table.
# - Left joins keep every loan, even without a matching account or collateral row.
# - NOTE(review): if collateral_silver holds several rows per loan_id, the
#   second join fans a loan out into several rows, and its EAD/RWA would be
#   counted more than once downstream. Confirm loan_id is unique there.
#
# Simplified risk weighting: a loan is treated as secured only when
# collateral_value is strictly positive. Secured loans get 50%. Everything else
# (no collateral row, null, zero or negative value) gets 100%. The same
# predicate drives secured_flag, so the two columns always agree; previously a
# negative collateral_value got the 50% weight while being flagged unsecured.
_is_secured = F.col("collateral_value") > 0

fact_exposure = (
    loans
      .join(accounts, "customer_id", "left")
      .join(collat, "loan_id", "left")
      # null > 0 evaluates to null, which falls through to otherwise(100.0).
      .withColumn("risk_weight", F.when(_is_secured, 50.0).otherwise(100.0))
      # EAD is set equal to the notional amount; no conversion factor is applied.
      .withColumn("EAD", F.col("notional"))
      .withColumn("RWA", F.round(F.col("EAD") * F.col("risk_weight") / 100, 2))
      .withColumn("secured_flag", F.when(_is_secured, 1).otherwise(0))
      # Driver-local calendar date, stored as a 'YYYY-MM-DD' string literal.
      .withColumn("reporting_date", F.lit(datetime.today().strftime("%Y-%m-%d")))
)

# ---------------------------------------------------------------------
# Drop the load_timestamp column(s) carried in by the joins and stamp a fresh one
# ---------------------------------------------------------------------
# The joined inputs may each bring their own load_timestamp, which can leave
# several same-named columns. Drop every column whose name starts with
# "load_timestamp", then add a single fresh one for this Gold load.
# NOTE(review): current_timestamp() is resolved per Spark query, so separate
# actions on this DataFrame can see different values.
_stale_ts_cols = [
    name for name in fact_exposure.columns if name.startswith("load_timestamp")
]
fact_exposure = (
    fact_exposure
      .drop(*_stale_ts_cols)
      .withColumn("load_timestamp", F.current_timestamp())
)

# ---------------------------------------------------------------------
# Basic DQ validation (assertion)
# ---------------------------------------------------------------------
# Hard DQ gate: fail the notebook before anything is written to Gold if any
# exposure row lacks a customer_id or has a risk_weight outside [0, 150].
# Both counts come from one aggregation pass instead of two separate scans.
# F.count() skips nulls, so F.count(F.when(cond, True)) counts exactly the rows
# where cond is true, i.e. the rows filter(cond).count() would return.
_dq_gate = fact_exposure.agg(
    F.count(F.when(F.col("customer_id").isNull(), True)).alias("null_customers"),
    F.count(F.when(~F.col("risk_weight").between(0, 150), True)).alias("bad_weights"),
).first()
null_customers = _dq_gate["null_customers"]
bad_weights = _dq_gate["bad_weights"]

# Raise explicitly: a bare `assert` is skipped under `python -O`.
# AssertionError is kept so the failure type is unchanged.
if null_customers or bad_weights:
    raise AssertionError(
        f"❌ DQ validation failed! null customer_id rows={null_customers}, "
        f"risk_weight outside [0, 150] rows={bad_weights}"
    )


# ---------------------------------------------------------------------
# Persist Gold outputs
# ---------------------------------------------------------------------
# Persist the fact table. The Delta table is the primary Gold copy; the JSON
# and Parquet directories are exports of the same rows.
_fact_delta_path = f"{gold_dir}/fact_exposure.delta"
fact_exposure.write.mode("overwrite").format("delta").save(_fact_delta_path)

# Re-read the committed Delta table and derive every later output from it.
# `fact_exposure` was a lazy plan: writing it three times re-ran the joins
# three times and re-evaluated current_timestamp() on each write. That left the
# JSON/Parquet copies with a different load_timestamp than the Delta table.
fact_exposure = spark.read.format("delta").load(_fact_delta_path)

fact_exposure.write.mode("overwrite").json(f"{gold_dir}/json")
fact_exposure.write.mode("overwrite").parquet(f"{gold_dir}/parquet")

# ---------------------------------------------------------------------
# MAS610 Report View
# ---------------------------------------------------------------------
# MAS610 report.
# - Register the fact table as the session temp view "fact_exposure".
# - Project it onto the report's column labels, ordered by customer then loan.
# - Save the result as Parquet.
_MAS610_REPORT_SQL = """
SELECT customer_id AS Customer_ID, loan_id AS Loan_ID, account_type AS Account_Type,
       branch_code AS Branch_Code, region AS Region, currency AS Currency,
       notional AS Exposure_Amount, collateral_value AS Collateral_Value,
       risk_weight AS Risk_Weight_Pct, EAD AS Exposure_at_Default, RWA AS Risk_Weighted_Asset,
       secured_flag AS Secured_Flag, reporting_date AS Reporting_Date
FROM fact_exposure ORDER BY Customer_ID, Loan_ID
"""

fact_exposure.createOrReplaceTempView("fact_exposure")
mas610_report = spark.sql(_MAS610_REPORT_SQL)
(
    mas610_report.write
        .mode("overwrite")
        .parquet(f"{gold_dir}/MAS610_Report.parquet")
)

# ---------------------------------------------------------------------
# Optional Summary
# ---------------------------------------------------------------------
# Per-currency roll-up of the report:
# - Loan_Count: distinct loan count.
# - Total_EAD, Total_RWA: summed EAD and RWA.
# - Avg_Risk_Pct: Total_RWA / Total_EAD expressed in percent.
_total_ead = F.sum("Exposure_at_Default")
_total_rwa = F.sum("Risk_Weighted_Asset")
summary = (
    mas610_report.groupBy("Currency")
      .agg(F.countDistinct("Loan_ID").alias("Loan_Count"),
           _total_ead.alias("Total_EAD"),
           _total_rwa.alias("Total_RWA"),
           # Guard the ratio. With zero total EAD, ANSI mode raises a
           # divide-by-zero error; return null instead, which is what
           # non-ANSI Spark already produced for that case.
           F.when(_total_ead != 0,
                  F.round(_total_rwa / _total_ead * 100, 2)).alias("Avg_Risk_Pct"))
)
# Spark writes a directory of part files at this path, not a single CSV file.
summary.write.mode("overwrite").csv(f"{gold_dir}/summary_currency.csv", header=True)

# ---------------------------------------------------------------------
# ✅ Extended DQ Results Logging (Optional but recommended)
# ---------------------------------------------------------------------
# Persist a per-rule DQ scorecard to Gold.
# - All rule counts and the total row count come from one aggregation pass over
#   fact_exposure, instead of one Spark job per count.
# - F.count() skips nulls, so F.count(F.when(cond, True)) counts exactly the
#   rows filter(cond).count() would.
# NOTE(review): mode("overwrite") keeps only the latest run. A failing DQ
# validation earlier in the notebook stops execution before this point, so
# failed runs are never recorded here. Confirm whether an appended history is
# wanted.
_dq_checks = [
    ("DQ_01", "Null_CustomerID", F.col("customer_id").isNull()),
    ("DQ_02", "Invalid_RiskWeight", ~F.col("risk_weight").between(0, 150)),
    ("DQ_03", "Negative_Notional", F.col("notional") < 0),
]
_dq_counts = fact_exposure.agg(
    F.count(F.lit(1)).alias("__total__"),
    *[F.count(F.when(cond, True)).alias(rule_id) for rule_id, _, cond in _dq_checks],
).first()

dq_rules = [
    (rule_id, rule_name, _dq_counts[rule_id]) for rule_id, rule_name, _ in _dq_checks
]

dq_df = (
    spark.createDataFrame(dq_rules, ["Rule_ID", "Rule_Name", "Failed_Count"])
         .withColumn("Total_Records", F.lit(_dq_counts["__total__"]))
         .withColumn("Pass_Flag", F.when(F.col("Failed_Count") == 0, "PASS").otherwise("FAIL"))
         .withColumn("DQ_Run_TS", F.current_timestamp())
)

dq_df.write.mode("overwrite").format("delta").save(f"{gold_dir}/dq_results.delta")

# Completion messages. The emoji were previously mis-encoded (UTF-8 bytes
# decoded as Mac Roman) and printed as garbage.
print("💾 DQ results written to dq_results.delta")
# NOTE(review): display() runs a new query over dq_df, so the shown DQ_Run_TS
# (current_timestamp) can differ slightly from the value persisted above.
display(dq_df)

print("🏁 Gold layer build & MAS610 calculation complete.")
