In [0]:
%run ./1_config.py

In [0]:
import os, importlib
# Notebook parameters: when running on Databricks, expose ENV, STORAGE_ACCOUNT and
# METASTORE_ACCOUNT as widgets and mirror the selected values into os.environ.
try:
    # `dbutils` is only injected by the Databricks runtime; anywhere else the
    # bare name lookup raises NameError.
    _widgets = dbutils.widgets
except NameError:
    # Not on Databricks: keep whatever is already present in os.environ.
    _widgets = None

# The widget calls live outside the `try` so that an unrelated NameError
# (e.g. a typo) is not silently swallowed.
if _widgets is not None:
    _widgets.dropdown("ENV", "dev", ["dev", "qa"], "Environment")
    _widgets.dropdown("STORAGE_ACCOUNT", "trafficsa2", ["trafficsa2", "trafficsaqa"], "Storage account")
    _widgets.text("METASTORE_ACCOUNT", "trafficsa2", "Metastore account")
    os.environ["ENV"] = _widgets.get("ENV").strip().lower()
    os.environ["STORAGE_ACCOUNT"] = _widgets.get("STORAGE_ACCOUNT").strip()
    # Strip BEFORE applying the fallback: a whitespace-only widget value is
    # truthy, so falling back first would store an empty metastore account.
    os.environ["METASTORE_ACCOUNT"] = (
        (_widgets.get("METASTORE_ACCOUNT") or "").strip() or os.environ["STORAGE_ACCOUNT"]
    )


In [0]:
# Databricks notebook source
# MAGIC %run ./1_config.py
# MAGIC %run ./2_setup.py

# COMMAND ----------

from pyspark.sql import functions as F, types as T
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.utils import AnalysisException
import re

# --------------------------------------------------------------------------------------
# Gold Builder
# --------------------------------------------------------------------------------------

class GoldBuilder:
    """
    Builds Gold star-schema tables from Silver:
      - fact_traffic_15min   (Region/Site/Detector x 15-minute interval)
      - fact_daily_summary   (Detector x Day)
    Also keeps:
      - dim_time             (Date, Hour, Year, Month, DayOfWeek, WeekdayFlag)
      - dim_detector         (NB_DETECTOR, NB_SCATS_SITE, NM_REGION, suburb via region_lookup)

    Assumptions in Silver:
      NB_SCATS_SITE, NB_DETECTOR, QT_INTERVAL_COUNT, V00..V95,
      NM_REGION, CT_RECORDS, QT_VOLUME_24HOUR, CT_ALARM_24HOUR,
      ReadingDate (DATE), ReadingTs (TIMESTAMP), NM_REGION_NORM, PartitionDate, load_time, source_file
    """

    def __init__(self, conf_obj: "Config" = None):
        """
        Resolve configuration and Spark session, derive every table name, then
        make sure the catalog/schema and both fact tables exist.

        conf_obj: configuration object; when falsy the module-level `conf` is used.
        """
        self.conf = conf_obj if conf_obj else conf

        active_session = SparkSession.getActiveSession()
        self.spark: SparkSession = active_session or SparkSession.builder.getOrCreate()

        self.catalog = self.conf.catalog
        self.db_name = self.conf.db_name

        # Every Gold object lives in <catalog>.<db_name>
        namespace = f"{self.catalog}.{self.db_name}"

        # Source table (resolved through the config helper)
        self.silver_fqn = self.conf.table_fqn(self.conf.silver_table)

        # Dimensions and lookup
        self.dim_time_fqn = f"{namespace}.dim_time"
        self.dim_detector_fqn = f"{namespace}.dim_detector"
        self.region_lookup_fqn = f"{namespace}.region_lookup"

        # Facts
        self.fact_15_fqn = f"{namespace}.fact_traffic_15min"
        self.fact_daily_fqn = f"{namespace}.fact_daily_summary"

        self._bootstrap_uc()
        self._ensure_fact_tables()

    # ---------------------------- bootstrap ----------------------------
    def _bootstrap_uc(self) -> None:
        """Create the catalog and schema if absent, then select both for the session."""
        schema_fqn = f"{self.catalog}.{self.db_name}"
        statements = (
            f"CREATE CATALOG IF NOT EXISTS {self.catalog}",
            f"CREATE SCHEMA IF NOT EXISTS {schema_fqn}",
            f"USE CATALOG {self.catalog}",
            f"USE {self.db_name}",
        )
        # Order matters: objects must exist before they can be selected.
        for statement in statements:
            self.spark.sql(statement)

    def _ensure_fact_tables(self) -> None:
        """
        Create both Delta fact tables when they do not exist yet.

        fact_traffic_15min is partitioned by (Year, Month); fact_daily_summary
        is unpartitioned. Existing tables are left untouched.
        """
        # 15-minute grain fact
        ddl_15min = f"""
            CREATE TABLE IF NOT EXISTS {self.fact_15_fqn} (
                TimeKey BIGINT,
                DetectorKey STRING,
                SiteKey STRING,
                RegionKey STRING,
                IntervalStartTime TIMESTAMP,
                Volume BIGINT,
                Year INT,
                Month INT
            )
            USING DELTA
            PARTITIONED BY (Year, Month)
        """
        # Daily grain fact
        ddl_daily = f"""
            CREATE TABLE IF NOT EXISTS {self.fact_daily_fqn} (
                DateKey INT,
                DetectorKey STRING,
                SiteKey STRING,
                RegionKey STRING,
                ReadingDate DATE,
                TotalVolume BIGINT,
                AlarmCount BIGINT,
                IntervalsWithData INT
            )
            USING DELTA
        """
        for ddl in (ddl_15min, ddl_daily):
            self.spark.sql(ddl)

    # ---------------------------- helpers ------------------------------
    @staticmethod
    def _hash_key(cols: list) -> F.Column:
        """
        Build a SHA-256 surrogate key column from the named columns.

        Each column is cast to string, NULL is replaced by the empty string, and
        the parts are joined with "||" before hashing.
        """
        parts = []
        for name in cols:
            as_text = F.col(name).cast("string")
            parts.append(F.coalesce(as_text, F.lit("")))
        joined = F.concat_ws("||", *parts)
        return F.sha2(joined, 256)

    @staticmethod
    def _find_vcols(df: DataFrame) -> list:
        """
        List the DataFrame's two-digit "Vxx" columns ordered by their numeric suffix.

        Matching ignores case; the returned names keep their original spelling.
        """
        pattern = re.compile(r'v(\d{2})', re.IGNORECASE)
        candidates = ((name, pattern.fullmatch(name)) for name in df.columns)
        indexed = [(int(hit.group(1)), name) for name, hit in candidates if hit]
        # Sort on the numeric suffix only (stable, so equal suffixes keep column order).
        indexed.sort(key=lambda pair: pair[0])
        return [name for _, name in indexed]

    @staticmethod
    def _melt_15min(df: DataFrame) -> DataFrame:
        """
        Unpivot the wide Vxx columns into one row per interval.

        Output columns: NB_SCATS_SITE, NB_DETECTOR, NM_REGION_NORM, ReadingDate,
        IntervalIndex (position of the Vxx column in suffix order), Volume (long)
        and IntervalStartTime (midnight of ReadingDate + IntervalIndex * 15 min).

        Raises ValueError when the input has no Vxx columns.
        """
        volume_cols = GoldBuilder._find_vcols(df)
        if len(volume_cols) == 0:
            raise ValueError("No Vxx (V00..V95) columns found in Silver.")

        id_cols = ["NB_SCATS_SITE", "NB_DETECTOR", "NM_REGION_NORM", "ReadingDate"]

        # Keep only the identifiers and the interval columns
        narrowed = df.select(*id_cols, *volume_cols)

        # One array element per interval; posexplode yields (position, value),
        # so the generator needs a two-name alias.
        volumes = F.array(*[F.col(name).cast("long") for name in volume_cols])
        long_rows = narrowed.select(
            *id_cols,
            F.posexplode_outer(volumes).alias("IntervalIndex", "Volume"),
        )

        # Midnight of the reading date, built from its yyyy-MM-dd text form
        day_text = F.date_format(F.col("ReadingDate"), "yyyy-MM-dd")
        day_start = F.to_timestamp(F.concat_ws(" ", day_text, F.lit("00:00:00")))

        with_midnight = long_rows.withColumn("Midnight", day_start)
        with_start = with_midnight.withColumn(
            "IntervalStartTime",
            F.expr("timestampadd(MINUTE, IntervalIndex * 15, Midnight)"),
        )
        return with_start.drop("Midnight")

    # ---------------------------- dimensions ---------------------------
    def _upsert_dim_time(self, intervals_df: DataFrame) -> None:
        """
        Insert missing (Date, Hour) rows into dim_time based on IntervalStartTime.

        intervals_df: any DataFrame exposing an IntervalStartTime timestamp column
                      (e.g. the output of _melt_15min).

        The dimension is insert-only: existing (Date, Hour) rows are never updated.
        Rows whose IntervalStartTime is NULL are ignored. They would produce a
        (NULL, NULL) key, which the equality anti-join below can never match, so
        a fresh NULL row would otherwise be appended on every run.
        """
        dim = (intervals_df
               .where(F.col("IntervalStartTime").isNotNull())
               .select(
                   F.to_date("IntervalStartTime").alias("Date"),
                   F.hour("IntervalStartTime").alias("Hour"),
                   F.year("IntervalStartTime").alias("Year"),
                   F.month("IntervalStartTime").alias("Month"),
                   F.date_format("IntervalStartTime","E").alias("DayOfWeek"),
                   F.when(F.dayofweek("IntervalStartTime").isin(1,7), F.lit(False))  # Sun=1, Sat=7
                     .otherwise(F.lit(True)).alias("WeekdayFlag")
               )
               .dropDuplicates(["Date","Hour"]))

        self.spark.sql(f"""
            CREATE TABLE IF NOT EXISTS {self.dim_time_fqn} (
                Date DATE, Hour INT, Year INT, Month INT, DayOfWeek STRING, WeekdayFlag BOOLEAN
            ) USING DELTA
        """)
        tgt = self.spark.table(self.dim_time_fqn).select("Date","Hour")
        new_rows = dim.join(tgt, on=["Date","Hour"], how="left_anti")

        # head(1) can stop at the first row found; a full count() is not needed
        # just to learn whether there is anything to append.
        if new_rows.head(1):
            new_rows.write.mode("append").format("delta").saveAsTable(self.dim_time_fqn)

    def _upsert_dim_detector(self, silver_df: DataFrame) -> None:
        """
        Append detector/site/region combinations that dim_detector does not hold yet,
        enriched with `suburb` from region_lookup when that table is available.

        silver_df: Silver rows providing NB_DETECTOR, NB_SCATS_SITE and NM_REGION_NORM.

        The dimension is insert-only: existing rows are never updated. If the
        lookup table (or one of its expected columns) cannot be resolved, `suburb`
        is filled with NULL instead of failing.
        """
        key_cols = ["NB_DETECTOR", "NB_SCATS_SITE", "NM_REGION"]

        self.spark.sql(f"""
            CREATE TABLE IF NOT EXISTS {self.dim_detector_fqn} (
                NB_DETECTOR INT, NB_SCATS_SITE INT, NM_REGION STRING, suburb STRING
            ) USING DELTA
        """)
        base = (silver_df
                .select(
                    F.col("NB_DETECTOR").cast("int").alias("NB_DETECTOR"),
                    F.col("NB_SCATS_SITE").cast("int").alias("NB_SCATS_SITE"),
                    F.col("NM_REGION_NORM").alias("NM_REGION")
                ).dropDuplicates())

        # Left join lookup to fill suburb
        try:
            lk = (self.spark.table(self.region_lookup_fqn)
                  .select(
                      F.upper(F.trim(F.col("NM_REGION"))).alias("NM_REGION"),
                      F.col("SUBURB").alias("suburb")
                  )
                  # Several lookup rows can normalise to the same region; keep one
                  # so the left join cannot fan out into duplicate dimension rows.
                  .dropDuplicates(["NM_REGION"]))
            base = base.join(lk, "NM_REGION", "left")
        except AnalysisException:
            base = base.withColumn("suburb", F.lit(None).cast("string"))

        tgt = self.spark.table(self.dim_detector_fqn).select(*key_cols).dropDuplicates()

        # Null-safe anti-join: with plain equality a NULL key never matches the
        # identical row already stored, so it would be re-appended on every run.
        candidates = base.alias("b")
        existing = tgt.alias("t")
        same_key = [F.col(f"b.{k}").eqNullSafe(F.col(f"t.{k}")) for k in key_cols]
        new_rows = (candidates
                    .join(existing, on=same_key, how="left_anti")
                    .select(*key_cols, "suburb"))

        # head(1) can stop at the first row found; a full count() is not needed.
        if new_rows.head(1):
            new_rows.write.mode("append").format("delta").saveAsTable(self.dim_detector_fqn)

    # ---------------------------- facts -------------------------------
    def _build_fact_15min(self, silver_df: DataFrame, since_load_time: str = None, full_rebuild: bool = False) -> int:
        """
        Load fact_traffic_15min from Silver.

        silver_df:       Silver DataFrame (wide Vxx layout).
        since_load_time: when given and not a full rebuild, only rows with
                         load_time >= this timestamp are processed.
        full_rebuild:    True  -> TRUNCATE the fact, then append everything;
                         False -> MERGE on (TimeKey, DetectorKey).

        Returns the number of interval rows sent to the fact table.

        Rows without ReadingDate, NB_DETECTOR or NB_SCATS_SITE are skipped, as in
        _build_fact_daily. A NULL ReadingDate yields a NULL TimeKey that MERGE can
        never match, so such rows would be inserted again on every incremental run.
        """
        # Local import: only the de-duplication step below needs Window.
        from pyspark.sql import Window

        src = silver_df
        if since_load_time and not full_rebuild:
            src = src.where(F.col("load_time") >= F.to_timestamp(F.lit(since_load_time)))

        src = src.dropna(subset=["ReadingDate", "NB_DETECTOR", "NB_SCATS_SITE"])

        if not full_rebuild:
            # Delta MERGE fails when several source rows match the same target row,
            # so keep only the most recently loaded Silver row per detector and day.
            newest_first = (Window
                            .partitionBy("NB_SCATS_SITE", "NB_DETECTOR", "ReadingDate")
                            .orderBy(F.col("load_time").desc()))
            src = (src
                   .withColumn("__rn", F.row_number().over(newest_first))
                   .where(F.col("__rn") == 1)
                   .drop("__rn"))

        melted = self._melt_15min(src).cache()
        try:
            fact_src = (melted
                        .withColumn("RegionKey", self._hash_key(["NM_REGION_NORM"]))
                        .withColumn("SiteKey",   self._hash_key(["NB_SCATS_SITE"]))
                        .withColumn("DetectorKey", self._hash_key(["NB_SCATS_SITE","NB_DETECTOR"]))
                        .withColumn("TimeKey", F.date_format("IntervalStartTime","yyyyMMddHHmm").cast("bigint"))
                        .withColumn("Year", F.year("IntervalStartTime"))
                        .withColumn("Month", F.month("IntervalStartTime"))
                        .select("TimeKey","DetectorKey","SiteKey","RegionKey","IntervalStartTime","Volume","Year","Month"))

            if full_rebuild:
                self.spark.sql(f"TRUNCATE TABLE {self.fact_15_fqn}")
                fact_src.write.mode("append").format("delta").partitionBy("Year","Month").saveAsTable(self.fact_15_fqn)
            else:
                fact_src.createOrReplaceTempView("__fact15_src")
                self.spark.sql(f"""
                    MERGE INTO {self.fact_15_fqn} AS tgt
                    USING __fact15_src AS src
                      ON tgt.TimeKey = src.TimeKey AND tgt.DetectorKey = src.DetectorKey
                    WHEN MATCHED THEN UPDATE SET *
                    WHEN NOT MATCHED THEN INSERT *
                """)

            # fact_src only adds/selects columns, so it has exactly as many rows
            # as the cached `melted` frame.
            return melted.count()
        finally:
            # Release the cache even when the write or MERGE raises.
            melted.unpersist()

    def _build_fact_daily(self, silver_df: DataFrame, since_load_time: str = None, full_rebuild: bool = False) -> int:
        """
        Load fact_daily_summary from Silver.

        silver_df:       Silver DataFrame carrying the per-day totals
                         (CT_RECORDS, QT_VOLUME_24HOUR, CT_ALARM_24HOUR).
        since_load_time: when given and not a full rebuild, only rows with
                         load_time >= this timestamp are processed.
        full_rebuild:    True  -> TRUNCATE the fact, then append everything;
                         False -> MERGE on (DateKey, DetectorKey).

        Returns the number of rows sent to the fact table.

        Rows without ReadingDate, NB_DETECTOR or NB_SCATS_SITE are skipped.
        """
        # Local import: only the de-duplication step below needs Window.
        from pyspark.sql import Window

        src = silver_df
        if since_load_time and not full_rebuild:
            src = src.where(F.col("load_time") >= F.to_timestamp(F.lit(since_load_time)))

        projected = [
            F.col("NB_SCATS_SITE").cast("int").alias("NB_SCATS_SITE"),
            F.col("NB_DETECTOR").cast("int").alias("NB_DETECTOR"),
            F.col("NM_REGION_NORM").alias("NM_REGION"),
            F.col("ReadingDate").alias("ReadingDate"),
            F.col("CT_RECORDS").cast("int").alias("IntervalsWithData"),
            F.col("QT_VOLUME_24HOUR").cast("long").alias("TotalVolume"),
            F.col("CT_ALARM_24HOUR").cast("long").alias("AlarmCount"),
        ]
        if not full_rebuild:
            # Carried along only to pick the newest row per merge key below.
            projected.append(F.col("load_time").alias("__load_time"))

        daily_src = (src
                     .select(*projected)
                     .dropna(subset=["ReadingDate","NB_DETECTOR","NB_SCATS_SITE"]))

        keyed = (daily_src
                 .withColumn("RegionKey", self._hash_key(["NM_REGION"]))
                 .withColumn("SiteKey",   self._hash_key(["NB_SCATS_SITE"]))
                 .withColumn("DetectorKey", self._hash_key(["NB_SCATS_SITE","NB_DETECTOR"]))
                 .withColumn("DateKey", F.date_format("ReadingDate","yyyyMMdd").cast("int")))

        if not full_rebuild:
            # Delta MERGE fails when several source rows match the same target row,
            # so keep only the most recently loaded row per (DateKey, DetectorKey).
            newest_first = (Window
                            .partitionBy("DateKey", "DetectorKey")
                            .orderBy(F.col("__load_time").desc()))
            keyed = (keyed
                     .withColumn("__rn", F.row_number().over(newest_first))
                     .where(F.col("__rn") == 1))

        # Cached so the write/MERGE and the final count see the same rows and the
        # pipeline is not recomputed from Silver just to count them.
        fact_src = (keyed
                    .select("DateKey","DetectorKey","SiteKey","RegionKey","ReadingDate","TotalVolume","AlarmCount","IntervalsWithData")
                    .cache())
        try:
            if full_rebuild:
                self.spark.sql(f"TRUNCATE TABLE {self.fact_daily_fqn}")
                fact_src.write.mode("append").format("delta").saveAsTable(self.fact_daily_fqn)
            else:
                fact_src.createOrReplaceTempView("__factdaily_src")
                self.spark.sql(f"""
                    MERGE INTO {self.fact_daily_fqn} AS tgt
                    USING __factdaily_src AS src
                      ON tgt.DateKey = src.DateKey AND tgt.DetectorKey = src.DetectorKey
                    WHEN MATCHED THEN UPDATE SET *
                    WHEN NOT MATCHED THEN INSERT *
                """)
            return fact_src.count()
        finally:
            # Release the cache even when the write or MERGE raises.
            fact_src.unpersist()

    # ---------------------------- public APIs -------------------------
    def rebuild_all(self) -> None:
        """
        Rebuild Gold from the complete Silver table.

        Adds any missing dimension rows first, then truncates and reloads both
        fact tables and prints the resulting row counts.
        """
        silver_df = self.spark.table(self.silver_fqn)

        # Dimensions: dim_time is derived from the melted intervals,
        # dim_detector straight from Silver.
        self._upsert_dim_time(self._melt_15min(silver_df))
        self._upsert_dim_detector(silver_df)

        # Facts: 15-minute grain first, then the daily summary.
        n_15min = self._build_fact_15min(silver_df, full_rebuild=True)
        n_daily = self._build_fact_daily(silver_df, full_rebuild=True)
        print(f"✅ Full Gold rebuild complete: 15min={n_15min}, daily={n_daily}")

    def incremental_upsert(self, since_load_time: str) -> None:
        """
        Bring Gold up to date with Silver rows loaded at or after `since_load_time`.

        since_load_time: timestamp string (e.g. ISO format) compared against
                         Silver's load_time column.

        Missing dimension rows are added for the new data, then both facts are merged.
        """
        silver_df = self.spark.table(self.silver_fqn)

        # Same cut-off filter feeds both dimension updates.
        cutoff = F.to_timestamp(F.lit(since_load_time))
        recent = silver_df.where(F.col("load_time") >= cutoff)

        self._upsert_dim_time(self._melt_15min(recent))
        self._upsert_dim_detector(recent)

        # The fact builders apply the load_time filter themselves.
        n_15min = self._build_fact_15min(silver_df, since_load_time=since_load_time, full_rebuild=False)
        n_daily = self._build_fact_daily(silver_df, since_load_time=since_load_time, full_rebuild=False)
        print(f"✅ Incremental Gold complete: 15min={n_15min}, daily={n_daily}")

    def validate(self) -> None:
        """
        Print row counts for both fact tables plus a preview of up to five
        partitions of the 15-minute fact.

        Listing partitions is best effort: any failure is reported as
        "(partitions not listed)" instead of raising.
        """
        fact_15 = self.spark.table(self.fact_15_fqn)
        fact_daily = self.spark.table(self.fact_daily_fqn)

        try:
            sample = self.spark.sql(f"SHOW PARTITIONS {self.fact_15_fqn}").limit(5).collect()
            if sample:
                part_preview = ", ".join([row[0] for row in sample])
            else:
                part_preview = "(no partitions yet)"
        except Exception:
            part_preview = "(partitions not listed)"

        rows_15 = fact_15.count()
        print(f"🔎 {self.fact_15_fqn}: {rows_15} rows, partitions preview: {part_preview}")
        rows_daily = fact_daily.count()
        print(f"🔎 {self.fact_daily_fqn}: {rows_daily} rows")

# --------------------------------------------------------------------------------------
# Example usage (run as a notebook cell)
# --------------------------------------------------------------------------------------

# Instantiating GoldBuilder already creates the catalog/schema and both fact
# tables if they are missing (see __init__).
# NOTE(review): `conf` is expected to come from the %run'd config notebook — confirm.
GB = GoldBuilder(conf)
# Option A: full rebuild (recommended after changes).
# This truncates both fact tables and reloads them from the whole Silver table
# every time the cell runs.
GB.rebuild_all()

# Option B: only process Silver rows loaded at or after a timestamp
# (swap with Option A above; running both would rebuild and then merge).
# GB.incremental_upsert("2025-05-01T00:00:00")

# Print row counts and a partition preview for the fact tables.
GB.validate()

In [0]:
# ---------- orchestration ----------
# Drives the CDF + MERGE Gold pipeline: resolve table names, pre-create the
# targets, stream Silver's change feed through a foreachBatch MERGE, then print
# row counts.
# NOTE(review): `Config`, `_bootstrap_uc`, `resolve_silver_fqn`,
# `ensure_region_lookup_coverage`, `precreate_gold_tables`,
# `make_foreach_batch_cdf_merge`, `DEV_MODE`, `CDF_FROM_VERSION`, `THRESHOLDS`
# and `col` are not defined in the visible cells — presumably they come from the
# %run'd notebooks; confirm before running this cell on its own.
# NOTE(review): this rebinds the module-level `conf` used by earlier cells.
conf  = Config()
spark = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()
# Module-level helper (not GoldBuilder._bootstrap_uc, which takes no arguments).
_bootstrap_uc(spark, conf.catalog, conf.db_name)

# sources
silver_fact_fqn   = resolve_silver_fqn(conf)
region_lookup_fqn = conf.table_fqn(conf.region_lookup)

# targets
t_region_hourly   = conf.table_fqn("traffic_gold_region_hourly")
t_detector_hourly = conf.table_fqn("traffic_gold_detector_hourly")
t_region_monthly  = conf.table_fqn("traffic_gold_region_monthly")
t_detector_cong   = conf.table_fqn("traffic_gold_detector_congestion")
t_gold_errors     = conf.table_fqn("traffic_gold_errors")  # kept for parity (not used in MERGE path)

# presumably checks that every Silver region has a lookup row and back-fills
# unknown ones — verify against the helper's definition
ensure_region_lookup_coverage(spark, silver_fact_fqn, region_lookup_fqn, backfill_unknown=True)

# precreate (first run) so MERGE targets exist
precreate_gold_tables(
    spark=spark,
    silver_fact_fqn=silver_fact_fqn,
    region_lookup_fqn=region_lookup_fqn,
    t_region_hourly=t_region_hourly,
    t_detector_hourly=t_detector_hourly,
    t_region_monthly=t_region_monthly,
    t_detector_cong=t_detector_cong,
    dev_mode=DEV_MODE
)

# CDF streaming source from Silver (ignore deletes; use postimages)
# Only "insert" and "update_postimage" change rows pass the filter, so deletes
# and update pre-images never reach the MERGE.
streaming_df = (
    spark.readStream
         .format("delta")
         .option("readChangeFeed", "true")
         .option("startingVersion", CDF_FROM_VERSION)   # or .option("startingTimestamp","2025-01-01")
         .table(silver_fact_fqn)
         .where(col("_change_type").isin("insert","update_postimage"))
)

# One checkpoint directory per source table; dots in the FQN become underscores
# so the name is a single path segment.
checkpoint_dir = f"{conf.checkpoint_base}/gold_cdf/{silver_fact_fqn.replace('.','_')}"

# Build the per-micro-batch callback handed to foreachBatch below.
foreach_batch = make_foreach_batch_cdf_merge(
    silver_fact_fqn=silver_fact_fqn,
    region_lookup_fqn=region_lookup_fqn,
    thresholds=THRESHOLDS,
    t_region_hourly=t_region_hourly,
    t_detector_hourly=t_detector_hourly,
    t_region_monthly=t_region_monthly,
    t_detector_cong=t_detector_cong
)

# trigger(once=True): process what is currently available, then stop; the
# awaitTermination() call below blocks until that run has finished.
q = (streaming_df.writeStream
     .foreachBatch(foreach_batch)
     .option("checkpointLocation", checkpoint_dir)
     .outputMode("update")   # foreachBatch ignores this for MERGE; harmless
     .trigger(once=True)
     .start())
q.awaitTermination()

# --- smoke/validate (optional) ---
# Best effort: a table that cannot be read is reported as missing instead of
# failing the cell.
for t in [t_region_hourly, t_detector_hourly, t_region_monthly, t_detector_cong]:
    try:
        print(t, spark.table(t).count(), "rows")
    except Exception as e:
        print("Missing:", t, e)
print("✅ gold (CDF + MERGE) completed.")