In [0]:
%run ./1_config.py

In [0]:
import os, importlib
def resolve_widget_env(get_widget):
    """Turn raw widget values into the environment settings used by this notebook.

    Args:
        get_widget: callable taking a widget name and returning its current
            string value (e.g. ``dbutils.widgets.get``).

    Returns:
        dict with keys ``ENV`` (stripped, lower-cased), ``STORAGE_ACCOUNT``
        (stripped) and ``METASTORE_ACCOUNT`` (stripped; falls back to the
        storage account when the widget is empty or whitespace-only).
    """
    env = get_widget("ENV").strip().lower()
    storage = get_widget("STORAGE_ACCOUNT").strip()
    # Strip BEFORE applying the fallback: a whitespace-only value is truthy and
    # would otherwise end up as an empty METASTORE_ACCOUNT.
    metastore = (get_widget("METASTORE_ACCOUNT") or "").strip() or storage
    return {"ENV": env, "STORAGE_ACCOUNT": storage, "METASTORE_ACCOUNT": metastore}


try:
    dbutils.widgets.dropdown("ENV", "dev", ["dev", "qa"], "Environment")
    dbutils.widgets.dropdown("STORAGE_ACCOUNT", "trafficsa2", ["trafficsa2", "trafficsaqa"], "Storage account")
    dbutils.widgets.text("METASTORE_ACCOUNT", "trafficsa2", "Metastore account")
    os.environ.update(resolve_widget_env(dbutils.widgets.get))
except NameError:
    # `dbutils` is not defined in this interpreter: keep whatever is already
    # in the environment instead of failing.
    pass


In [0]:
from typing import Optional, List
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.utils import AnalysisException
from pyspark.sql import functions as F
from pyspark.sql.functions import (
    col, when, year, month, upper, trim, coalesce, lit,
    sum as sum_, count as count_, min as min_, max as max_
)

# Passed as `dev_mode` to precreate_gold_tables(): when True every Gold target is
# overwritten with an empty schema on each run; when False a target is only
# created if it cannot be read.
DEV_MODE   = True                    # dev = allow precreate overwrite; prod = set False
# Hourly_Volume cut-offs used when deriving Congestion_Flag / Congestion_Level:
# > low -> flagged and "Low", > med -> "Medium", > high -> "High", else "Normal".
THRESHOLDS = {"low": 40, "med": 75, "high": 110}
# Value of the `startingVersion` option on the Silver change-feed stream.
CDF_FROM_VERSION   = "0"             # or set startingTimestamp instead (e.g. "2025-01-01T00:00:00Z")

def _bootstrap_uc(spark: SparkSession, catalog: str, db_name: str) -> None:
    """Create the catalog and schema if absent, then select both for the session.

    Args:
        spark: session on which the statements are executed.
        catalog: catalog name, interpolated verbatim into the SQL.
        db_name: schema name, interpolated verbatim into the SQL.
    """
    statements = (
        f"CREATE CATALOG IF NOT EXISTS {catalog}",
        f"CREATE SCHEMA IF NOT EXISTS {catalog}.{db_name}",
        f"USE CATALOG {catalog}",
        f"USE {db_name}",
    )
    for statement in statements:
        spark.sql(statement)

def _norm_nm_region_expr(c: str) -> F.Column:
    """Return a column expression for column ``c``, trimmed and then upper-cased."""
    trimmed = F.trim(F.col(c))
    return F.upper(trimmed)

def resolve_silver_fqn(conf: Config, explicit: Optional[str] = None) -> str:
    """Return the first readable Silver table among a list of candidate names.

    Candidates are probed in order: ``explicit`` (when truthy), the configured
    name from ``conf``, then a few known alternative spellings. Each distinct
    name is probed once by reading a single row.

    Args:
        conf: configuration providing ``catalog``, ``db_name``, ``silver_table``
            and ``table_fqn()``.
        explicit: optional fully qualified name to try first.

    Returns:
        The first candidate that can be read.

    Raises:
        RuntimeError: when no candidate can be read; the tried names and the
            matching tables in the schema are printed first.
    """
    session = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()
    catalog = conf.catalog
    db = conf.db_name
    prefix = f"{catalog}.{db}"

    ordered = [explicit] if explicit else []
    ordered.extend([
        conf.table_fqn(conf.silver_table),   # canonical (traffic_silver_events)
        f"{prefix}.traffic_silver_events",
        f"{prefix}.traffic_silver_envents",  # common typo
        f"{prefix}.traffic_silver_fact",     # legacy
        f"{prefix}.traffic_silver_event",    # singular
    ])
    # dict.fromkeys drops repeats while keeping first-seen order.
    unique_names = list(dict.fromkeys(ordered))

    for name in unique_names:
        try:
            session.table(name).limit(1).count()
        except AnalysisException:
            continue
        print(f"✅ Using Silver source: {name}")
        return name

    print("❌ Could not find a Silver table. Tried:", sorted(unique_names))
    listing = session.sql(f"SHOW TABLES IN {catalog}.{db} LIKE 'traffic_silver*'")
    listing.show(truncate=False)
    raise RuntimeError("Silver table not found. Align your config or rename the table.")

def ensure_region_lookup_coverage(
    spark: SparkSession,
    silver_fact_fqn: str,
    region_lookup_fqn: str,
    backfill_unknown: bool = True
) -> None:
    """Make sure every Silver NM_REGION has a row in the region lookup table.

    Steps:
      1. Create the lookup table (NM_REGION, SUBURB) if it cannot be read.
      2. Find normalized (trimmed, upper-cased) Silver regions that have no
         lookup row and, when ``backfill_unknown`` is True, append them with
         SUBURB = 'Unknown'.
      3. Rewrite the lookup table normalized and de-duplicated on NM_REGION.

    NULL regions are excluded from step 2: a NULL key never matches in a join,
    so it would be reported as missing and appended again on every run.

    Args:
        spark: active session.
        silver_fact_fqn: fully qualified Silver table name.
        region_lookup_fqn: fully qualified lookup table name.
        backfill_unknown: append placeholder rows for uncovered regions.
    """
    try:
        spark.table(region_lookup_fqn).limit(1).count()
    except AnalysisException:
        spark.sql(f"""
            CREATE TABLE IF NOT EXISTS {region_lookup_fqn} (
                NM_REGION STRING,
                SUBURB   STRING
            )
            USING DELTA
        """)
    lk_norm = (spark.table(region_lookup_fqn)
        .select(_norm_nm_region_expr("NM_REGION").alias("NM_REGION"), col("SUBURB"))
        .dropDuplicates(["NM_REGION"]))
    silver_regions = (spark.table(silver_fact_fqn)
        .select(_norm_nm_region_expr("NM_REGION").alias("NM_REGION"))
        .where(col("NM_REGION").isNotNull())   # NULL keys can never be "covered"
        .distinct())
    missing = silver_regions.join(lk_norm.select("NM_REGION"), "NM_REGION", "left_anti")
    miss_cnt = missing.count()
    if miss_cnt > 0:
        if backfill_unknown:
            (missing.withColumn("SUBURB", lit("Unknown"))
                    .write.format("delta").mode("append").saveAsTable(region_lookup_fqn))
            print(f"✅ Backfilled {miss_cnt} missing NM_REGION in region_lookup as 'Unknown'")
        else:
            # Surface the gap instead of hiding it when backfill is disabled.
            print(f"⚠️ {miss_cnt} NM_REGION value(s) have no row in region_lookup (backfill disabled)")
    # normalize & dedupe table
    lk_final = (spark.table(region_lookup_fqn)
        .select(_norm_nm_region_expr("NM_REGION").alias("NM_REGION"), col("SUBURB"))
        .dropDuplicates(["NM_REGION"]))
    lk_final.write.format("delta").mode("overwrite").option("overwriteSchema","true").saveAsTable(region_lookup_fqn)

# --------- create empty targets (first time) ----------
def _create_or_overwrite_empty(df0: DataFrame, table_fqn: str, dev_mode: bool) -> None:
    """Write ``df0`` to ``table_fqn`` as a Delta table, subject to ``dev_mode``.

    With ``dev_mode`` on, the table is always overwritten (schema included).
    With it off, the table is written only when reading it raises
    ``AnalysisException``; a readable table is left untouched.
    """
    session = df0.sparkSession
    if not dev_mode:
        try:
            session.table(table_fqn).limit(1).count()
        except AnalysisException:
            pass  # not readable -> fall through and create it
        else:
            return
    writer = df0.write.format("delta").mode("overwrite").option("overwriteSchema", "true")
    writer.saveAsTable(table_fqn)

def precreate_gold_tables(
    spark: SparkSession,
    silver_fact_fqn: str,
    region_lookup_fqn: str,
    t_region_hourly: str,
    t_detector_hourly: str,
    t_region_monthly: str,
    t_detector_cong: str,
    dev_mode: bool
) -> None:
    """Create the four Gold tables with the right schemas and no rows.

    The schemas are derived by running the Gold aggregations over a zero-row
    slice of Silver joined to the region lookup. Each result is handed to
    ``_create_or_overwrite_empty`` together with ``dev_mode``.
    """
    region_key = _norm_nm_region_expr("NM_REGION").alias("NM_REGION")
    volume = coalesce(col("Volume").cast("long"), lit(0)).alias("Volume")

    # limit(0) keeps the schema but guarantees no rows.
    empty_fact = (spark.table(silver_fact_fqn)
                       .limit(0)
                       .select(region_key, "NB_SCATS_SITE", "NB_DETECTOR", "Interval_EndTime", volume))
    lookup = (spark.table(region_lookup_fqn)
                   .select(region_key, col("SUBURB"))
                   .dropDuplicates(["NM_REGION"]))
    enriched = (empty_fact.join(lookup, ["NM_REGION"], "left")
                          .withColumn("SUBURB", coalesce(col("SUBURB"), lit("Unknown"))))

    detector_hourly = (
        enriched
        .groupBy("NM_REGION", "SUBURB", "NB_SCATS_SITE", "NB_DETECTOR", "Interval_EndTime")
        .agg(sum_("Volume").alias("Hourly_Volume"))
    )
    region_hourly = (
        detector_hourly
        .groupBy("NM_REGION", "SUBURB", "Interval_EndTime")
        .agg(sum_("Hourly_Volume").alias("Hourly_Volume"))
    )
    region_monthly = (
        region_hourly
        .withColumn("Year", year("Interval_EndTime"))
        .withColumn("Month", month("Interval_EndTime"))
        .groupBy("NM_REGION", "SUBURB", "Year", "Month")
        .agg(sum_("Hourly_Volume").alias("Monthly_Volume"))
    )
    detector_congestion = (
        detector_hourly
        .withColumn("Congestion_Flag", lit(False))
        .withColumn("Congestion_Level", lit("Normal"))
    )

    targets = (
        (region_hourly,       t_region_hourly),
        (detector_hourly,     t_detector_hourly),
        (region_monthly,      t_region_monthly),
        (detector_congestion, t_detector_cong),
    )
    for empty_df, table_name in targets:
        _create_or_overwrite_empty(empty_df, table_name, dev_mode)
    print("✅ Pre-created Gold tables (empty schemas).")

# --------- MERGE helpers ----------
def _merge_detector_hourly(spark: SparkSession, df: DataFrame, target: str):
    """Upsert detector-hour volumes from ``df`` into the Delta table ``target``.

    Rows are matched on (NM_REGION, NB_SCATS_SITE, NB_DETECTOR, Interval_EndTime).
    SUBURB is deliberately NOT part of the match key: it is an attribute looked
    up from the region table, so keying on it would insert a second row (and
    leave the stale one behind) whenever a region's suburb changes. It is
    refreshed on match instead.

    ``df`` must hold at most one row per match key, otherwise MERGE fails.

    Args:
        spark: session used to run the MERGE.
        df: source rows; registered as temp view ``upd_dh``.
        target: fully qualified target table name.
    """
    df.createOrReplaceTempView("upd_dh")
    spark.sql(f"""
    MERGE INTO {target} AS t
    USING upd_dh AS s
    ON  t.NM_REGION = s.NM_REGION
    AND t.NB_SCATS_SITE = s.NB_SCATS_SITE
    AND t.NB_DETECTOR = s.NB_DETECTOR
    AND t.Interval_EndTime = s.Interval_EndTime
    WHEN MATCHED THEN UPDATE SET
      t.SUBURB = s.SUBURB,
      t.Hourly_Volume = s.Hourly_Volume
    WHEN NOT MATCHED THEN INSERT (
      NM_REGION, SUBURB, NB_SCATS_SITE, NB_DETECTOR, Interval_EndTime, Hourly_Volume
    ) VALUES (
      s.NM_REGION, s.SUBURB, s.NB_SCATS_SITE, s.NB_DETECTOR, s.Interval_EndTime, s.Hourly_Volume
    )
    """)

def _merge_region_hourly(spark: SparkSession, df: DataFrame, target: str):
    """Upsert region-hour volumes from ``df`` into the Delta table ``target``.

    Rows are matched on (NM_REGION, Interval_EndTime). SUBURB is an attribute
    looked up from the region table, not a key: matching on it would insert a
    duplicate row whenever a region's suburb changes, so it is refreshed on
    match instead.

    ``df`` must hold at most one row per match key, otherwise MERGE fails.

    Args:
        spark: session used to run the MERGE.
        df: source rows; registered as temp view ``upd_rh``.
        target: fully qualified target table name.
    """
    df.createOrReplaceTempView("upd_rh")
    spark.sql(f"""
    MERGE INTO {target} AS t
    USING upd_rh AS s
    ON  t.NM_REGION = s.NM_REGION
    AND t.Interval_EndTime = s.Interval_EndTime
    WHEN MATCHED THEN UPDATE SET
      t.SUBURB = s.SUBURB,
      t.Hourly_Volume = s.Hourly_Volume
    WHEN NOT MATCHED THEN INSERT (
      NM_REGION, SUBURB, Interval_EndTime, Hourly_Volume
    ) VALUES (
      s.NM_REGION, s.SUBURB, s.Interval_EndTime, s.Hourly_Volume
    )
    """)

def _merge_region_monthly(spark: SparkSession, df: DataFrame, target: str):
    """Upsert region-month volumes from ``df`` into the Delta table ``target``.

    Rows are matched on (NM_REGION, Year, Month). SUBURB is an attribute looked
    up from the region table, not a key: matching on it would insert a
    duplicate row whenever a region's suburb changes, so it is refreshed on
    match instead.

    ``df`` must hold at most one row per match key, otherwise MERGE fails.

    Args:
        spark: session used to run the MERGE.
        df: source rows; registered as temp view ``upd_rm``.
        target: fully qualified target table name.
    """
    df.createOrReplaceTempView("upd_rm")
    spark.sql(f"""
    MERGE INTO {target} AS t
    USING upd_rm AS s
    ON  t.NM_REGION = s.NM_REGION
    AND t.Year = s.Year
    AND t.Month = s.Month
    WHEN MATCHED THEN UPDATE SET
      t.SUBURB = s.SUBURB,
      t.Monthly_Volume = s.Monthly_Volume
    WHEN NOT MATCHED THEN INSERT (
      NM_REGION, SUBURB, Year, Month, Monthly_Volume
    ) VALUES (
      s.NM_REGION, s.SUBURB, s.Year, s.Month, s.Monthly_Volume
    )
    """)

def _merge_detector_congestion(spark: SparkSession, df: DataFrame, target: str):
    """Upsert detector-hour congestion rows from ``df`` into the Delta table ``target``.

    Rows are matched on (NM_REGION, NB_SCATS_SITE, NB_DETECTOR, Interval_EndTime).
    SUBURB is an attribute looked up from the region table, not a key: matching
    on it would insert a duplicate row whenever a region's suburb changes, so
    it is refreshed on match together with the volume and congestion columns.

    ``df`` must hold at most one row per match key, otherwise MERGE fails.

    Args:
        spark: session used to run the MERGE.
        df: source rows; registered as temp view ``upd_dc``.
        target: fully qualified target table name.
    """
    df.createOrReplaceTempView("upd_dc")
    spark.sql(f"""
    MERGE INTO {target} AS t
    USING upd_dc AS s
    ON  t.NM_REGION = s.NM_REGION
    AND t.NB_SCATS_SITE = s.NB_SCATS_SITE
    AND t.NB_DETECTOR = s.NB_DETECTOR
    AND t.Interval_EndTime = s.Interval_EndTime
    WHEN MATCHED THEN UPDATE SET
      t.SUBURB = s.SUBURB,
      t.Hourly_Volume = s.Hourly_Volume,
      t.Congestion_Flag = s.Congestion_Flag,
      t.Congestion_Level = s.Congestion_Level
    WHEN NOT MATCHED THEN INSERT (
      NM_REGION, SUBURB, NB_SCATS_SITE, NB_DETECTOR, Interval_EndTime, Hourly_Volume, Congestion_Flag, Congestion_Level
    ) VALUES (
      s.NM_REGION, s.SUBURB, s.NB_SCATS_SITE, s.NB_DETECTOR, s.Interval_EndTime, s.Hourly_Volume, s.Congestion_Flag, s.Congestion_Level
    )
    """)

# --------- foreachBatch using CDF (correctness) ----------
def make_foreach_batch_cdf_merge(
    silver_fact_fqn: str,
    region_lookup_fqn: str,
    thresholds: dict,
    t_region_hourly: str,
    t_detector_hourly: str,
    t_region_monthly: str,
    t_detector_cong: str
):
    """Build the ``foreachBatch`` handler that refreshes the Gold tables.

    For each non-empty micro-batch the handler takes the keys touched by the
    batch, re-reads the current Silver rows for those keys, recomputes the
    detector-hour, region-hour and region-month aggregates plus the congestion
    labels, and MERGEs them into the four targets.

    The key set and the three aggregates are persisted for the duration of the
    batch: ``det_imp`` feeds two MERGEs and all three aggregates are counted
    for the log line, so without caching each of them would be recomputed from
    Silver several times.

    Args:
        silver_fact_fqn: fully qualified Silver table name.
        region_lookup_fqn: fully qualified region lookup table name.
        thresholds: dict with ``low``, ``med`` and ``high`` Hourly_Volume cut-offs.
        t_region_hourly, t_detector_hourly, t_region_monthly, t_detector_cong:
            fully qualified Gold target table names.

    Returns:
        A function ``(batch_df, batch_id)`` suitable for ``writeStream.foreachBatch``.

    Raises:
        KeyError: if ``thresholds`` lacks one of the three keys.
    """
    low_t, med_t, high_t = thresholds["low"], thresholds["med"], thresholds["high"]

    def foreach_batch(batch_df: DataFrame, batch_id: int):
        if batch_df.isEmpty():
            print(f"Batch {batch_id}: no changes.")
            return
        spark = batch_df.sparkSession

        cached = []  # everything persisted for this batch; released in `finally`

        def _persist(df: DataFrame) -> DataFrame:
            cached.append(df.persist())
            return df

        try:
            # 1) Identify impacted keys from CDF batch (normalized region)
            changed = _persist(batch_df
                .select(
                    _norm_nm_region_expr("NM_REGION").alias("NM_REGION"),
                    "NB_SCATS_SITE","NB_DETECTOR","Interval_EndTime"
                )
                .dropDuplicates()
            )
            keys_rh = changed.select("NM_REGION","Interval_EndTime").distinct()
            keys_rm = (keys_rh
                .withColumn("Year", F.year("Interval_EndTime"))
                .withColumn("Month", F.month("Interval_EndTime"))
                .select("NM_REGION","Year","Month").distinct())

            # 2) Current Silver rows; joined to the impacted keys in the steps below
            silver = spark.table(silver_fact_fqn).select(
                _norm_nm_region_expr("NM_REGION").alias("NM_REGION"),
                "NB_SCATS_SITE","NB_DETECTOR","Interval_EndTime",
                coalesce(col("Volume").cast("long"), lit(0)).alias("Volume")
            )

            lk = (spark.table(region_lookup_fqn)
                    .select(_norm_nm_region_expr("NM_REGION").alias("NM_REGION"), col("SUBURB"))
                    .dropDuplicates(["NM_REGION"]))

            # 3) Detector-hour (exact latest values)
            det_imp = _persist(silver.join(changed, ["NM_REGION","NB_SCATS_SITE","NB_DETECTOR","Interval_EndTime"], "inner")
                             .join(lk, ["NM_REGION"], "left")
                             .withColumn("SUBURB", coalesce(col("SUBURB"), lit("Unknown")))
                             .groupBy("NM_REGION","SUBURB","NB_SCATS_SITE","NB_DETECTOR","Interval_EndTime")
                             .agg(F.sum("Volume").alias("Hourly_Volume")))

            # 4) Region-hour (recompute from snapshot for impacted keys)
            region_hourly = _persist(silver.join(keys_rh, ["NM_REGION","Interval_EndTime"], "inner")
                                   .groupBy("NM_REGION","Interval_EndTime").agg(F.sum("Volume").alias("Hourly_Volume"))
                                   .join(lk, ["NM_REGION"], "left")
                                   .withColumn("SUBURB", coalesce(col("SUBURB"), lit("Unknown")))
                                   .select("NM_REGION","SUBURB","Interval_EndTime","Hourly_Volume"))

            # 5) Region-month (recompute from snapshot for impacted months)
            region_monthly = _persist(silver
                .withColumn("Year", F.year("Interval_EndTime"))
                .withColumn("Month", F.month("Interval_EndTime"))
                .join(keys_rm, ["NM_REGION","Year","Month"], "inner")
                .groupBy("NM_REGION","Year","Month").agg(F.sum("Volume").alias("Monthly_Volume"))
                .join(lk, ["NM_REGION"], "left")
                .withColumn("SUBURB", coalesce(col("SUBURB"), lit("Unknown")))
                .select("NM_REGION","SUBURB","Year","Month","Monthly_Volume"))

            # 6) Detector congestion (derived from detector_hourly)
            detector_congestion = (det_imp
                .withColumn("Congestion_Flag", F.when(col("Hourly_Volume") > low_t, True).otherwise(False))
                .withColumn("Congestion_Level",
                    F.when(col("Hourly_Volume") > high_t, "High")
                     .when(col("Hourly_Volume") > med_t, "Medium")
                     .when(col("Hourly_Volume") > low_t, "Low")
                     .otherwise("Normal"))
            )

            # 7) MERGE into targets (upsert)
            _merge_detector_hourly(spark, det_imp,         t_detector_hourly)
            _merge_region_hourly(spark,   region_hourly,   t_region_hourly)
            _merge_region_monthly(spark,  region_monthly,  t_region_monthly)
            _merge_detector_congestion(spark, detector_congestion, t_detector_cong)

            print(f"Batch {batch_id}: merged {det_imp.count()} detector-hrs; "
                  f"{region_hourly.count()} region-hrs; {region_monthly.count()} region-months.")
        finally:
            # Release the cache even when a MERGE fails, so retries start clean.
            for df in cached:
                df.unpersist()

    return foreach_batch



In [0]:
# ---------- orchestration ----------
# Resolve names, make sure lookup + Gold targets exist, then drain the Silver
# change feed once through the MERGE-based foreachBatch handler.
conf  = Config()
spark = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()
_bootstrap_uc(spark, conf.catalog, conf.db_name)

silver_fact_fqn   = resolve_silver_fqn(conf)
region_lookup_fqn = conf.table_fqn(conf.region_lookup)

# targets
t_region_hourly   = conf.table_fqn("traffic_gold_region_hourly")
t_detector_hourly = conf.table_fqn("traffic_gold_detector_hourly")
t_region_monthly  = conf.table_fqn("traffic_gold_region_monthly")
t_detector_cong   = conf.table_fqn("traffic_gold_detector_congestion")
t_gold_errors     = conf.table_fqn("traffic_gold_errors")  # kept for parity (not used in MERGE path)

ensure_region_lookup_coverage(spark, silver_fact_fqn, region_lookup_fqn, backfill_unknown=True)

# precreate (first run) so MERGE targets exist
precreate_gold_tables(
    spark=spark,
    silver_fact_fqn=silver_fact_fqn,
    region_lookup_fqn=region_lookup_fqn,
    t_region_hourly=t_region_hourly,
    t_detector_hourly=t_detector_hourly,
    t_region_monthly=t_region_monthly,
    t_detector_cong=t_detector_cong,
    dev_mode=DEV_MODE
)

# CDF streaming source from Silver (ignore deletes; use postimages)
streaming_df = (
    spark.readStream
         .format("delta")
         .option("readChangeFeed", "true")
         .option("startingVersion", CDF_FROM_VERSION)   # or .option("startingTimestamp","2025-01-01")
         .table(silver_fact_fqn)
         .where(col("_change_type").isin("insert","update_postimage"))
)

checkpoint_dir = f"{conf.checkpoint_base}/gold_cdf/{silver_fact_fqn.replace('.','_')}"

if DEV_MODE:
    # In dev mode the Gold tables were just overwritten with empty schemas.
    # If the old checkpoint survived, the stream would resume after the last
    # processed version (startingVersion is only honoured for a new query) and
    # the emptied tables would never be repopulated. Drop the checkpoint so
    # the change feed is replayed from CDF_FROM_VERSION.
    try:
        dbutils.fs.rm(checkpoint_dir, True)
        print(f"♻️ DEV_MODE: cleared checkpoint {checkpoint_dir}")
    except NameError:
        print(f"⚠️ DEV_MODE: dbutils unavailable; remove {checkpoint_dir} manually to replay the change feed")

foreach_batch = make_foreach_batch_cdf_merge(
    silver_fact_fqn=silver_fact_fqn,
    region_lookup_fqn=region_lookup_fqn,
    thresholds=THRESHOLDS,
    t_region_hourly=t_region_hourly,
    t_detector_hourly=t_detector_hourly,
    t_region_monthly=t_region_monthly,
    t_detector_cong=t_detector_cong
)

q = (streaming_df.writeStream
     .foreachBatch(foreach_batch)
     .option("checkpointLocation", checkpoint_dir)
     .outputMode("update")   # foreachBatch ignores this for MERGE; harmless
     .trigger(once=True)
     .start())
q.awaitTermination()

# --- smoke/validate (optional) ---
for t in [t_region_hourly, t_detector_hourly, t_region_monthly, t_detector_cong]:
    try:
        print(t, spark.table(t).count(), "rows")
    except Exception as e:
        print("Missing:", t, e)
print("✅ gold (CDF + MERGE) completed.")