In [0]:
%run ./1_config.py

In [0]:
# --- inline params (dev/qa) ---
# Expose ENV / STORAGE_ACCOUNT / METASTORE_ACCOUNT as notebook widgets and copy
# the selected values into os.environ.
import os, importlib
try:
    dbutils.widgets.dropdown("ENV", "dev", ["dev", "qa"], "Environment")
    dbutils.widgets.dropdown("STORAGE_ACCOUNT", "trafficsa2", ["trafficsa2", "trafficsaqa"], "Storage account")
    dbutils.widgets.text("METASTORE_ACCOUNT", "trafficsa2", "Metastore account")
    os.environ["ENV"] = dbutils.widgets.get("ENV").strip().lower()
    os.environ["STORAGE_ACCOUNT"] = dbutils.widgets.get("STORAGE_ACCOUNT").strip()
    # Strip BEFORE applying the fallback: a blank or whitespace-only widget
    # value must fall back to the storage account instead of exporting "".
    os.environ["METASTORE_ACCOUNT"] = (
        (dbutils.widgets.get("METASTORE_ACCOUNT") or "").strip()
        or os.environ["STORAGE_ACCOUNT"]
    )
except NameError:
    # `dbutils` is not defined in this interpreter; leave the environment
    # variables exactly as they already are.
    pass

In [0]:
# ---- imports ----
from typing import Optional, List
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
from pyspark.sql.functions import (
    col, lit, when, to_timestamp, from_unixtime, unix_timestamp, expr,
    concat_ws, sha2, upper, trim, coalesce, current_timestamp
)

# Length of one volume interval (15 minutes), used for the epoch-second
# arithmetic in the interval-end and normalisation transforms below.
FIFTEEN_MIN = 15 * 60  # seconds

class SilverLoader:
    def __init__(self, conf: Optional[Config] = None):
        """Resolve the configuration, Spark session and table names for the loader.

        Args:
            conf: Configuration object. When omitted (or falsy) a fresh
                ``Config()`` is created.
        """
        self.conf = conf if conf else Config()

        # Reuse the active session when there is one, otherwise get/create one.
        active_session = SparkSession.getActiveSession()
        self.spark = active_session if active_session else SparkSession.builder.getOrCreate()

        cfg = self.conf
        self.catalog = cfg.catalog
        self.db = cfg.db_name
        self.bronze_fqn = cfg.table_fqn(cfg.bronze_table)

        # Canonical silver table name comes from Config.
        self.silver_table = cfg.silver_table
        self.silver_fqn = cfg.table_fqn(self.silver_table)

        # Checkpoint directory: dots in the table name are replaced by underscores.
        checkpoint_root = cfg.checkpoint_base
        checkpoint_leaf = self.silver_table.replace('.', '_')
        self.chk_dir = f"{checkpoint_root}/silver/{checkpoint_leaf}"

    def _bootstrap_uc(self):
        """Create the catalog and schema if they are missing and select them.

        Issues four statements in order: create catalog, create schema,
        ``USE CATALOG`` and ``USE`` for the schema.
        """
        statements = (
            f"CREATE CATALOG IF NOT EXISTS {self.catalog}",
            f"CREATE SCHEMA IF NOT EXISTS {self.catalog}.{self.db}",
            f"USE CATALOG {self.catalog}",
            f"USE {self.db}",
        )
        for statement in statements:
            self.spark.sql(statement)

    def _create_silver_if_missing(self):
        """Create the silver Delta table when it does not exist yet.

        The DDL declares the nine silver columns and sets two table
        properties: change data feed enabled, and a
        ``delta.constraints.nonneg_volume`` entry with ``Volume >= 0``.
        """
        ddl = f"""
        CREATE TABLE IF NOT EXISTS {self.silver_fqn} (
            DetKey STRING,
            Fact_ID STRING,
            NM_REGION STRING,
            NB_SCATS_SITE INT,
            NB_DETECTOR INT,
            Interval_EndTime TIMESTAMP,
            Volume BIGINT,
            load_time TIMESTAMP,
            PartitionDate DATE
        )
        USING DELTA
        TBLPROPERTIES (
          delta.enableChangeDataFeed = true,
          delta.constraints.nonneg_volume = 'Volume >= 0'
        )
        """
        self.spark.sql(ddl)

    # ---------- transforms ----------
    @staticmethod
    def _clean_negatives(df: DataFrame) -> DataFrame:
        """Clamp negative interval volumes to zero.

        Every ``V00`` .. ``V95`` column present in ``df`` is cast to ``long``
        and values below zero are replaced by 0. NULLs stay NULL because the
        ``< 0`` test is not true for them. All other columns, and the column
        order, are left unchanged.

        Args:
            df: Input frame, possibly containing only some ``Vnn`` columns.

        Returns:
            The frame with the cleaned ``Vnn`` columns; ``df`` itself when it
            has none of them.
        """
        existing = set(df.columns)
        targets = [name for name in (f"V{i:02d}" for i in range(96)) if name in existing]
        if not targets:
            return df

        replacements = {}
        for name in targets:
            as_long = col(name).cast("long")
            replacements[name] = when(as_long < 0, lit(0)).otherwise(as_long)

        # One projection for all columns instead of up to 96 chained
        # withColumn calls, which bloat the query plan.
        return df.withColumns(replacements)

    @staticmethod
    def _unpivot_96(df: DataFrame) -> DataFrame:
        """Unpivot the ``V00`` .. ``V95`` columns into one row per interval.

        Args:
            df: Frame with ``NB_SCATS_SITE``, ``QT_INTERVAL_COUNT``,
                ``NB_DETECTOR``, ``NM_REGION`` and at least one ``Vnn`` column.

        Returns:
            A frame with those four columns plus ``Volume`` (the value of the
            ``Vnn`` column) and ``Interval_Index`` (the integer ``nn``).

        Raises:
            ValueError: If ``df`` contains none of the ``V00`` .. ``V95``
                columns; the generated ``stack(0, )`` expression would
                otherwise fail with an opaque SQL error.
        """
        available = set(df.columns)
        present = [name for name in (f"V{i:02d}" for i in range(96)) if name in available]
        if not present:
            raise ValueError("No V00..V95 interval columns found; nothing to unpivot.")

        # stack(n, 'V00', V00, 'V01', V01, ...) emits one (label, value) row per column.
        stack_expr = ", ".join(f"'{c}', {c}" for c in present)
        out = df.selectExpr(
            "NB_SCATS_SITE",
            "QT_INTERVAL_COUNT",
            "NB_DETECTOR",
            "NM_REGION",
            f"stack({len(present)}, {stack_expr}) as (Interval_Label, Volume)"
        )
        # The label is 'Vnn'; characters 2-3 are the zero-padded interval index.
        out = out.withColumn("Interval_Index", expr("int(substring(Interval_Label, 2, 2))"))
        return out.drop("Interval_Label")

    @staticmethod
    def _compute_interval_end(df: DataFrame) -> DataFrame:
        """Add ``Interval_EndTime`` derived from ``QT_INTERVAL_COUNT`` and ``Interval_Index``.

        The end time is ``QT_INTERVAL_COUNT`` minus ``(95 - Interval_Index)``
        fifteen-minute steps, so index 95 maps to ``QT_INTERVAL_COUNT`` itself
        and each lower index is 15 minutes earlier.

        NOTE(review): this treats QT_INTERVAL_COUNT as the end of the last
        interval. If the source stores the start-of-day date there, the
        computed end times fall on the previous day - confirm against the
        bronze data definition.

        Args:
            df: Frame with ``QT_INTERVAL_COUNT`` and ``Interval_Index``.

        Returns:
            The frame with an added ``Interval_EndTime`` timestamp column.
        """
        base = df.withColumn("QT_TS", to_timestamp("QT_INTERVAL_COUNT"))
        end_epoch = unix_timestamp("QT_TS") - ((95 - col("Interval_Index")) * FIFTEEN_MIN)
        # Convert epoch seconds straight to a timestamp. Going through
        # from_unixtime() produces a local-time *string*, which is ambiguous
        # during a daylight-saving fall-back hour and can shift the instant.
        out = base.withColumn("Interval_EndTime", F.timestamp_seconds(end_epoch))
        return out.drop("QT_TS")

    @staticmethod
    def _normalize_and_key(df: DataFrame) -> DataFrame:
        """Normalise values, snap the end time to the 15-minute grid and add keys.

        Steps:
          * ``NM_REGION`` is trimmed and upper-cased.
          * ``Volume`` is cast to ``long``; NULL becomes 0.
          * ``Interval_EndTime`` is truncated to a multiple of 15 minutes.
          * ``DetKey`` and ``Fact_ID`` are both set to the SHA-256 of region,
            site, detector and end time joined with ``§``.
          * ``load_time`` (current timestamp) and ``PartitionDate`` (date of
            the end time) are added.
          * Only one row per ``DetKey`` is kept.

        Args:
            df: Frame with ``NM_REGION``, ``NB_SCATS_SITE``, ``NB_DETECTOR``,
                ``Interval_EndTime`` and ``Volume``.

        Returns:
            The normalised, keyed and de-duplicated frame.
        """
        df = (df.withColumn("NM_REGION", upper(trim(col("NM_REGION"))))
                .withColumn("Volume", coalesce(col("Volume").cast("long"), lit(0))))

        # Snap to the 15-minute grid in epoch seconds and convert directly to a
        # timestamp. A from_unixtime() string round-trip is ambiguous in a
        # daylight-saving fall-back hour and could move rows to another instant.
        snapped_epoch = (F.unix_timestamp("Interval_EndTime") / FIFTEEN_MIN).cast("bigint") * FIFTEEN_MIN
        df = df.withColumn("Interval_EndTime", F.timestamp_seconds(snapped_epoch))

        detkey = sha2(concat_ws("§","NM_REGION","NB_SCATS_SITE","NB_DETECTOR", col("Interval_EndTime").cast("string")), 256)
        df = (df.withColumn("DetKey", detkey)
                .withColumn("Fact_ID", detkey)
                .withColumn("load_time", current_timestamp())
                .withColumn("PartitionDate", F.to_date("Interval_EndTime")))
        return df.dropDuplicates(["DetKey"])

    def transform(self, bronze_batch: DataFrame) -> DataFrame:
        """Turn a bronze batch into rows shaped like the silver table.

        Runs the four transform stages in order (clean negatives, unpivot,
        compute interval end, normalise and key) and projects the nine silver
        columns.

        Args:
            bronze_batch: Batch read from the bronze table.

        Returns:
            Frame with exactly the silver columns, in table order.
        """
        stages = (
            self._clean_negatives,
            self._unpivot_96,
            self._compute_interval_end,
            self._normalize_and_key,
        )
        shaped = bronze_batch
        for stage in stages:
            shaped = stage(shaped)

        silver_columns = [
            "DetKey",
            "Fact_ID",
            "NM_REGION",
            "NB_SCATS_SITE",
            "NB_DETECTOR",
            "Interval_EndTime",
            "Volume",
            "load_time",
            "PartitionDate",
        ]
        return shaped.select(*silver_columns)

    # ---------- MERGE ----------
    def _merge_upsert(self, micro: DataFrame):
        """MERGE a transformed micro-batch into the silver table, keyed on ``DetKey``.

        Matching rows have every column overwritten; unmatched rows are
        inserted.

        Args:
            micro: Frame with the silver columns, at most one row per
                ``DetKey``.
        """
        micro.createOrReplaceTempView("silver_updates")
        # The temp view lives in the session that owns `micro`. Inside
        # foreachBatch that session may not be `self.spark`, in which case
        # `silver_updates` would not resolve there, so run the MERGE on the
        # batch's own session. Fall back to self.spark where DataFrame has no
        # `sparkSession` attribute.
        session = getattr(micro, "sparkSession", self.spark)
        session.sql(f"""
        MERGE INTO {self.silver_fqn} AS t
        USING silver_updates AS s
        ON t.DetKey = s.DetKey
        WHEN MATCHED THEN UPDATE SET
          t.Fact_ID = s.Fact_ID,
          t.NM_REGION = s.NM_REGION,
          t.NB_SCATS_SITE = s.NB_SCATS_SITE,
          t.NB_DETECTOR = s.NB_DETECTOR,
          t.Interval_EndTime = s.Interval_EndTime,
          t.Volume = s.Volume,
          t.load_time = s.load_time,
          t.PartitionDate = s.PartitionDate
        WHEN NOT MATCHED THEN INSERT (
          DetKey, Fact_ID, NM_REGION, NB_SCATS_SITE, NB_DETECTOR, Interval_EndTime, Volume, load_time, PartitionDate
        ) VALUES (
          s.DetKey, s.Fact_ID, s.NM_REGION, s.NB_SCATS_SITE, s.NB_DETECTOR, s.Interval_EndTime, s.Volume, s.load_time, s.PartitionDate
        )
        """)

    def make_foreach_batch(self):
        """Build the callback passed to ``foreachBatch``.

        Returns:
            A function ``(batch_df, batch_id)`` that logs and skips empty
            batches, and otherwise transforms the batch and merges it into
            the silver table.
        """
        def foreach_batch(batch_df: DataFrame, batch_id: int):
            if not batch_df.isEmpty():
                print(f"Batch {batch_id}: transforming + MERGE upsert ...")
                silver_rows = self.transform(batch_df)
                self._merge_upsert(silver_rows)
                print(f"Batch {batch_id}: upsert complete.")
            else:
                print(f"Batch {batch_id}: no data.")

        return foreach_batch

    def run_once(self):
        """Run a single streaming pass from bronze to silver and wait for it.

        Ensures the catalog/schema and the silver table exist, then streams
        the bronze Delta table through the foreachBatch upsert using this
        loader's checkpoint directory, and blocks until the query terminates.

        NOTE(review): newer Spark releases appear to favour
        ``trigger(availableNow=True)`` over ``once=True`` - confirm before
        switching.
        """
        self._bootstrap_uc()
        self._create_silver_if_missing()

        bronze_stream = self.spark.readStream.format("delta").table(self.bronze_fqn)

        writer = bronze_stream.writeStream.foreachBatch(self.make_foreach_batch())
        writer = writer.option("checkpointLocation", self.chk_dir)
        writer = writer.outputMode("update").trigger(once=True)

        query = writer.start()
        query.awaitTermination()

    def validate(self):
        """Check that the silver table holds no duplicate keys.

        Two checks run against the whole table: no ``DetKey`` occurs more than
        once, and no (region, site, detector, interval end time) combination
        occurs more than once.

        Raises:
            AssertionError: If either check finds duplicates. Raised
                explicitly so the checks still run under ``python -O``, where
                ``assert`` statements are stripped.
        """
        silver = self.spark.table(self.silver_fqn)

        def duplicated_groups(*keys):
            # Number of key combinations that occur more than once.
            return silver.groupBy(*keys).count().filter("count > 1").count()

        if duplicated_groups("DetKey") != 0:
            raise AssertionError("Silver contains duplicate DetKey rows.")

        natural_key = ("NM_REGION", "NB_SCATS_SITE", "NB_DETECTOR", "Interval_EndTime")
        if duplicated_groups(*natural_key) != 0:
            # The natural key is per interval end time, not per hour.
            raise AssertionError("Silver contains duplicate detector-interval rows.")

        print(f"✅ {self.silver_fqn}: no duplicate keys.")


# run cell (silver): build the loader, run one streaming pass, then validate.
loader = SilverLoader()   # no conf passed, so the loader creates its own Config()
loader.run_once()         # trigger(once=True) + awaitTermination()
loader.validate()         # raises AssertionError if duplicate keys are found
print("✅ silver upsert/validate OK")