In [0]:
%run ./1_config.py

In [0]:
# --- inline params (dev/qa) ---
# Expose ENV / STORAGE_ACCOUNT / METASTORE_ACCOUNT as notebook widgets and copy
# the selected values into os.environ.
# NOTE(review): this cell runs after `%run ./1_config.py`; confirm the config
# reads these environment variables lazily (or is reloaded) so the widget
# selection actually takes effect.
import os
import importlib

# `dbutils` only exists inside a Databricks notebook session. Outside of one
# the widgets are skipped (best effort) and the environment is left untouched.
# The try block is kept to the single lookup that can raise NameError.
try:
    _widgets = dbutils.widgets
except NameError:
    _widgets = None

if _widgets is not None:
    _widgets.dropdown("ENV", "dev", ["dev", "qa"], "Environment")
    _widgets.dropdown("STORAGE_ACCOUNT", "trafficsa2", ["trafficsa2", "trafficsaqa"], "Storage account")
    _widgets.text("METASTORE_ACCOUNT", "trafficsa2", "Metastore account")

    os.environ["ENV"] = _widgets.get("ENV").strip().lower()
    os.environ["STORAGE_ACCOUNT"] = _widgets.get("STORAGE_ACCOUNT").strip()

    # Strip BEFORE applying the fallback: a blank or whitespace-only text
    # widget must fall back to the storage account rather than yield "".
    _metastore_account = (_widgets.get("METASTORE_ACCOUNT") or "").strip()
    os.environ["METASTORE_ACCOUNT"] = _metastore_account or os.environ["STORAGE_ACCOUNT"]

In [0]:
from pyspark.sql import functions as F, types as T
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.utils import AnalysisException
from typing import Optional, List

# --------------------------------------------------------------------------------------
# Silver Loader
# --------------------------------------------------------------------------------------

class SilverLoader:
    """
    Silver layer for SCATS 'Traffic Signal Volume':
      - Reads Bronze table: {catalog}.{db}.raw_traffic
      - Light cleansing/standardization (region normalization, time parsing)
      - Idempotent upsert via record hash MERGE
      - Keeps the wide V00..V95 columns for Gold to reshape later

    Constructing an instance has side effects: it creates the catalog and
    schema if they are missing, switches the session to them, and creates the
    Silver Delta table if it does not exist.
    """

    # Wide per-interval volume columns V00..V95.
    _VOLUME_COLS = tuple(f"V{i:02d}" for i in range(96))
    # Key / metric columns that are cast to INT alongside the volume columns.
    _INT_COLS = (
        "NB_SCATS_SITE",
        "NB_DETECTOR",
        "CT_RECORDS",
        "QT_VOLUME_24HOUR",
        "CT_ALARM_24HOUR",
    )
    # Temp view used as the MERGE source.
    _INCOMING_VIEW = "__incoming_silver"

    def __init__(self, conf_obj: Optional["Config"] = None):
        """
        Args:
            conf_obj: Configuration object exposing ``catalog``, ``db_name``,
                ``bronze_table``, ``silver_table`` and ``table_fqn(name)``.
                When omitted, the module-level ``conf`` is used
                (presumably defined by ``1_config.py`` -- confirm).
        """
        self.conf = conf_obj if conf_obj is not None else conf
        self.catalog = self.conf.catalog
        self.db_name = self.conf.db_name
        self.spark: SparkSession = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()

        self.bronze_fqn = self.conf.table_fqn(self.conf.bronze_table)   # raw_traffic
        self.silver_fqn = self.conf.table_fqn(self.conf.silver_table)   # e.g., traffic_silver

        self._bootstrap_uc()
        self._create_silver_if_not_exists()

    # -------------------------- UC/bootstrap --------------------------
    def _bootstrap_uc(self) -> None:
        """Create the catalog and schema if missing and make them current."""
        self.spark.sql(f"CREATE CATALOG IF NOT EXISTS {self.catalog}")
        self.spark.sql(f"CREATE SCHEMA IF NOT EXISTS {self.catalog}.{self.db_name}")
        self.spark.sql(f"USE CATALOG {self.catalog}")
        self.spark.sql(f"USE {self.db_name}")

    # -------------------------- table DDL -----------------------------
    def _create_silver_if_not_exists(self) -> None:
        """
        Silver schema mirrors Bronze plus cleaned columns & keys:
          - ReadingDate (DATE), ReadingTs (TIMESTAMP)
          - NM_REGION normalized
          - record_hash to support idempotent MERGE

        The statement is CREATE TABLE IF NOT EXISTS, so it is safe to call
        repeatedly.
        """
        cols96 = ",\n                ".join(f"{c} INT" for c in self._VOLUME_COLS)
        self.spark.sql(f"""
            CREATE TABLE IF NOT EXISTS {self.silver_fqn} (
                NB_SCATS_SITE INT,
                NB_DETECTOR INT,
                QT_INTERVAL_COUNT STRING,
                {cols96},
                NM_REGION STRING,
                CT_RECORDS INT,
                QT_VOLUME_24HOUR INT,
                CT_ALARM_24HOUR INT,
                PartitionDate DATE,
                load_time TIMESTAMP,
                source_file STRING,

                -- normalized/derived
                ReadingDate DATE,
                ReadingTs TIMESTAMP,
                NM_REGION_NORM STRING,

                -- deterministic hash for idempotent upsert
                record_hash STRING
            )
            USING DELTA
            PARTITIONED BY (PartitionDate)
        """)

    # -------------------------- transforms ----------------------------
    @staticmethod
    def _parse_ts(df: DataFrame) -> DataFrame:
        """
        Add ``ReadingTs`` (TIMESTAMP) and ``ReadingDate`` (DATE) parsed from
        the ``QT_INTERVAL_COUNT`` string.

        Explicit patterns are tried in order, followed by Spark's default
        parsing; the first non-NULL result wins.
        NOTE(review): relies on a mismatched pattern yielding NULL rather than
        raising (non-ANSI behaviour) -- confirm for the target runtime.
        """
        raw = F.col("QT_INTERVAL_COUNT")
        patterns = (
            "yyyy-MM-dd HH:mm:ss",
            "yyyy/MM/dd HH:mm:ss",
            "dd/MM/yyyy HH:mm:ss",
            "yyyy-MM-dd",
            "yyyy/MM/dd",
        )
        ts = F.coalesce(
            *[F.to_timestamp(raw, p) for p in patterns],
            F.to_timestamp(raw),
        )
        return df.withColumn("ReadingTs", ts).withColumn("ReadingDate", F.to_date(ts))

    @staticmethod
    def _normalize_region(df: DataFrame) -> DataFrame:
        """Add ``NM_REGION_NORM``: ``NM_REGION`` trimmed and upper-cased."""
        return df.withColumn("NM_REGION_NORM", F.upper(F.trim(F.col("NM_REGION"))))

    @staticmethod
    def _record_hash(df: DataFrame) -> DataFrame:
        """
        Add ``record_hash``: SHA-256 over site, detector, reading date and
        source file, joined with ``||``. Reloading the same file therefore
        produces the same keys and MERGE deduplicates.

        ``concat_ws`` silently skips NULL arguments, which would shift the
        remaining parts and let different keys collide. Every part is
        therefore coalesced to a positional placeholder, and when the date
        could not be parsed the raw ``QT_INTERVAL_COUNT`` text is used so
        distinct unparsable readings stay distinct. Rows whose key parts are
        all non-NULL hash exactly as before.
        """
        empty = F.lit("")
        date_part = F.coalesce(
            F.date_format(F.col("ReadingDate"), "yyyy-MM-dd"),
            F.col("QT_INTERVAL_COUNT").cast("string"),
            empty,
        )
        return df.withColumn(
            "record_hash",
            F.sha2(
                F.concat_ws(
                    "||",
                    F.coalesce(F.col("NB_SCATS_SITE").cast("string"), empty),
                    F.coalesce(F.col("NB_DETECTOR").cast("string"), empty),
                    date_part,
                    F.coalesce(F.col("source_file"), empty),
                ),
                256
            )
        )

    @staticmethod
    def _dedupe_incoming(df: DataFrame) -> DataFrame:
        """
        Keep one row per ``record_hash`` (the one with the latest
        ``load_time`` when that column is present).

        A Delta MERGE fails when several source rows match the same target
        row, and would insert duplicates into an empty target, so the source
        must be unique on the merge key.
        """
        # Local import: Window is only needed by this helper.
        from pyspark.sql import Window

        if "load_time" not in df.columns:
            return df.dropDuplicates(["record_hash"])

        newest_first = Window.partitionBy("record_hash").orderBy(
            F.col("load_time").desc_nulls_last()
        )
        return (
            df.withColumn("__rn", F.row_number().over(newest_first))
              .where(F.col("__rn") == 1)
              .drop("__rn")
        )

    def _select_and_cast(self, df: DataFrame) -> DataFrame:
        """Ensure numeric types for V00..V95 and metrics; preserve layout."""
        present = set(df.columns)
        casts = {
            c: F.col(c).cast("int")
            for c in (*self._VOLUME_COLS, *self._INT_COLS)
            if c in present
        }
        # One projection for all casts instead of ~100 chained withColumn
        # calls, which bloat the logical plan.
        return df.withColumns(casts) if casts else df

    def _project_columns(self, df: DataFrame) -> DataFrame:
        """Select the Silver columns in table order, skipping any that are absent."""
        ordered = (
            ["NB_SCATS_SITE", "NB_DETECTOR", "QT_INTERVAL_COUNT"]
            + list(self._VOLUME_COLS)
            + ["NM_REGION", "CT_RECORDS", "QT_VOLUME_24HOUR", "CT_ALARM_24HOUR",
               "PartitionDate", "load_time", "source_file",
               "ReadingDate", "ReadingTs", "NM_REGION_NORM", "record_hash"]
        )
        present = set(df.columns)
        return df.select(*[c for c in ordered if c in present])

    # -------------------------- upsert logic --------------------------
    def upsert_from_bronze(self, since_load_time: Optional[str] = None) -> int:
        """
        Incrementally upserts from Bronze into Silver.
        If since_load_time is provided (ISO timestamp string), only rows with load_time >= that are processed.
        Returns number of rows written/updated (distinct record hashes in the batch).
        """
        bronze = self.spark.table(self.bronze_fqn)

        if since_load_time:
            bronze = bronze.where(F.col("load_time") >= F.to_timestamp(F.lit(since_load_time)))

        # prepare
        staged = (bronze
                  .transform(self._select_and_cast)
                  .transform(self._parse_ts)
                  .transform(self._normalize_region)
                  .transform(self._record_hash)
                  .transform(self._dedupe_incoming)
        )
        incoming = self._project_columns(staged).cache()

        try:
            # Make sure the target exists (idempotent DDL).
            self._create_silver_if_not_exists()

            # Counting first materializes the cache, so MERGE and the returned
            # number are based on the same snapshot of the batch.
            written = incoming.count()
            if written == 0:
                return 0

            incoming.createOrReplaceTempView(self._INCOMING_VIEW)

            # MERGE keys: record_hash is the single deterministic key.
            merge_sql = f"""
        MERGE INTO {self.silver_fqn} AS tgt
        USING {self._INCOMING_VIEW} AS src
          ON tgt.record_hash = src.record_hash
        WHEN MATCHED THEN UPDATE SET *
        WHEN NOT MATCHED THEN INSERT *
        """
            self.spark.sql(merge_sql)
            return written
        finally:
            # Always release the temp view and the cached batch, even when
            # the MERGE raises.
            self.spark.catalog.dropTempView(self._INCOMING_VIEW)
            incoming.unpersist()

    # -------------------------- utilities ----------------------------
    def rebuild_all(self) -> int:
        """
        Full refresh: truncate and reload from Bronze.

        The truncate runs before the reload, so Silver is empty if the reload
        fails part-way. Returns the value of ``upsert_from_bronze``.
        """
        self.spark.sql(f"TRUNCATE TABLE {self.silver_fqn}")
        return self.upsert_from_bronze(since_load_time=None)

    def validate(self, show_sample: int = 5) -> None:
        """
        Print the row count and partition range of the Silver table and, when
        ``show_sample`` is non-zero, show that many sample rows.
        """
        df = self.spark.table(self.silver_fqn)
        total = df.count()
        parts = df.select("PartitionDate").distinct().orderBy("PartitionDate").collect()
        first = parts[0]["PartitionDate"] if parts else None
        last = parts[-1]["PartitionDate"] if parts else None
        print(f"✅ {self.silver_fqn}: {total} rows across {len(parts)} partitions (first={first}, last={last}).")
        if show_sample:
            sample = df.limit(show_sample)
            try:
                display(sample)
            except NameError:
                # `display` is a notebook builtin; fall back when it is absent.
                sample.show(truncate=False)

# --------------------------------------------------------------------------------------
# Example usage (run as a notebook cell)
# --------------------------------------------------------------------------------------

# Build the loader from the `conf` object (presumably defined by the earlier
# `%run ./1_config.py` -- confirm). Constructing it bootstraps the
# catalog/schema and creates the Silver table if it does not exist.
SL = SilverLoader(conf)  # uses ENV from 1_config.py widgets if present
# Full refresh: truncates the Silver table, then reloads everything from Bronze.
rows = SL.rebuild_all()
print(f"🔁 Silver rows written: {rows}")
# Print the row/partition summary and show a small sample (default: 5 rows).
SL.validate()
