In [0]:
%run ./1_config.py

In [0]:
# --- inline params (dev/qa) ---
import os
import importlib


def apply_widget_params(widgets) -> None:
    """
    Declare the notebook widgets and copy their values into os.environ.

    Sets ENV (lower-cased), STORAGE_ACCOUNT and METASTORE_ACCOUNT. A blank or
    whitespace-only METASTORE_ACCOUNT falls back to the storage account.

    Args:
        widgets: object exposing dropdown/text/get, i.e. `dbutils.widgets`.
    """
    widgets.dropdown("ENV", "dev", ["dev", "qa"], "Environment")
    widgets.dropdown("STORAGE_ACCOUNT", "trafficsa2", ["trafficsa2", "trafficsaqa"], "Storage account")
    widgets.text("METASTORE_ACCOUNT", "trafficsa2", "Metastore account")

    storage_account = widgets.get("STORAGE_ACCOUNT").strip()
    # Strip BEFORE applying the fallback: a widget holding only spaces is truthy
    # and would otherwise end up as an empty METASTORE_ACCOUNT.
    metastore_account = (widgets.get("METASTORE_ACCOUNT") or "").strip() or storage_account

    os.environ["ENV"] = widgets.get("ENV").strip().lower()
    os.environ["STORAGE_ACCOUNT"] = storage_account
    os.environ["METASTORE_ACCOUNT"] = metastore_account


try:
    apply_widget_params(dbutils.widgets)
except NameError:
    # `dbutils` is not defined in this session: keep whatever the environment
    # already provides instead of failing.
    pass


In [0]:
from pyspark.sql import functions as F, types as T
from pyspark.sql.utils import AnalysisException
from datetime import datetime, timedelta
import re
from typing import Optional, List
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import (
    col, current_timestamp, to_date, input_file_name,
    to_timestamp, coalesce
)

class LoadRawTraffic:
    """
    Bronze loader for SCATS 'Traffic Signal Volume' CSVs.
    Writes to {catalog}.{db}.raw_traffic (schema created in 2_setup.py).

    The target is a Delta table partitioned by PartitionDate, which is derived
    from QT_INTERVAL_COUNT. Two ingestion paths are provided:

    * batch_load  - one VSDATA_YYYYMMDD.csv file per calendar day
    * stream_load - Auto Loader (cloudFiles) over a file-name pattern

    Both paths run the same transforms and append to the same table.
    """

    # Column the Delta table is partitioned by (see _create_raw_table_if_not_exists).
    _PARTITION_COL = "PartitionDate"

    # The 96 per-interval volume columns V00 .. V95, in file order.
    _VOLUME_COLS = tuple(f"V{i:02d}" for i in range(96))

    # Explicit patterns tried, in order, when parsing QT_INTERVAL_COUNT.
    _DATE_FORMATS = (
        "yyyy-MM-dd HH:mm:ss",
        "yyyy/MM/dd HH:mm:ss",
        "dd/MM/yyyy HH:mm:ss",
        "yyyy-MM-dd",
        "yyyy/MM/dd",
        "dd/MM/yyyy",
    )

    def __init__(self, catalog: str, table_name: str, checkpoint_dir: Optional[str] = None, env: Optional[str] = None):
        """
        Args:
            catalog: catalog to create/use; also passed to Config when `env` is not given.
            table_name: bronze table name, resolved to a full name via Config.table_fqn.
            checkpoint_dir: checkpoint root for this table; defaults to
                "{checkpoint_base}/bronze/{table_name with '.' replaced by '_'}".
            env: environment key handed to Config (takes precedence over `catalog`).
        """
        # Config is not defined in this cell; it is expected to be in scope already
        # (the notebook %run's 1_config.py first).
        self.conf = Config(env or catalog)
        self.catalog = catalog
        self.db_name = self.conf.db_name
        self.landing_zone = self.conf.raw_data_path
        self.table_name = table_name
        self.table_fqn = self.conf.table_fqn(table_name)
        base_chk = f"{self.conf.checkpoint_base}/bronze"
        self.checkpoint_dir = checkpoint_dir or f"{base_chk}/{self.table_name.replace('.', '_')}"
        self.spark: SparkSession = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()

        # Ensure UC context
        self._bootstrap_uc()

        # Explicit schema to avoid drift (strings vs ints)
        self.schema = self._build_schema()

    # --------------------- UC bootstrap & table ---------------------
    def _bootstrap_uc(self) -> None:
        """Create the catalog and schema if missing and make them the session defaults."""
        self.spark.sql(f"CREATE CATALOG IF NOT EXISTS {self.catalog}")
        self.spark.sql(f"CREATE SCHEMA IF NOT EXISTS {self.catalog}.{self.db_name}")
        self.spark.sql(f"USE CATALOG {self.catalog}")
        self.spark.sql(f"USE {self.db_name}")

    @staticmethod
    def _raw_table_columns_sql() -> str:
        """Return the 'V00 INT, ... V95 INT' fragment, laid out for the CREATE TABLE statement."""
        return ",\n                ".join(f"{name} INT" for name in LoadRawTraffic._VOLUME_COLS)

    def _create_raw_table_if_not_exists(self) -> None:
        """Create the partitioned Delta bronze table if it does not exist yet."""
        cols96 = self._raw_table_columns_sql()
        self.spark.sql(f"""
            CREATE TABLE IF NOT EXISTS {self.table_fqn} (
                NB_SCATS_SITE INT,
                QT_INTERVAL_COUNT STRING,
                NB_DETECTOR INT,
                {cols96},
                NM_REGION STRING,
                CT_RECORDS INT,
                QT_VOLUME_24HOUR INT,
                CT_ALARM_24HOUR INT,
                PartitionDate DATE,
                load_time TIMESTAMP,
                source_file STRING
            )
            USING DELTA
            PARTITIONED BY (PartitionDate)
        """)

    def create_db(self) -> None:
        """Ensure catalog, schema and the bronze table all exist."""
        self._bootstrap_uc()
        self._create_raw_table_if_not_exists()

    # --------------------- schema & transforms ---------------------
    def _build_schema(self) -> T.StructType:
        """Build the explicit CSV schema; column order mirrors the table's data columns."""
        fields = [
            T.StructField("NB_SCATS_SITE", T.IntegerType(), True),
            T.StructField("QT_INTERVAL_COUNT", T.StringType(), True),
            T.StructField("NB_DETECTOR", T.IntegerType(), True),
        ]
        fields += [T.StructField(name, T.IntegerType(), True) for name in self._VOLUME_COLS]
        fields += [
            T.StructField("NM_REGION", T.StringType(), True),
            T.StructField("CT_RECORDS", T.IntegerType(), True),
            T.StructField("QT_VOLUME_24HOUR", T.IntegerType(), True),
            T.StructField("CT_ALARM_24HOUR", T.IntegerType(), True),
        ]
        return T.StructType(fields)

    def _schema_hints(self) -> str:
        """Render self.schema as 'NAME type, NAME type, ...' for cloudFiles.schemaHints."""
        return ", ".join(
            f"{field.name} {field.dataType.simpleString()}" for field in self.schema.fields
        )

    @staticmethod
    def _drop_if_exists(df: DataFrame, cols: List[str]) -> DataFrame:
        """Drop whichever of `cols` are present in `df`; absent names are ignored."""
        present = [c for c in cols if c in df.columns]
        return df.drop(*present) if present else df

    @staticmethod
    def _cast_volume_cols_int(df: DataFrame) -> DataFrame:
        """Cast every V00..V95 column present in `df` to int."""
        # Safe cast in case CSV parser brought them as strings.
        # One withColumns projection instead of 96 chained withColumn calls keeps
        # the logical plan small.
        present = set(df.columns)
        casts = {
            name: col(name).cast("int")
            for name in LoadRawTraffic._VOLUME_COLS
            if name in present
        }
        return df.withColumns(casts) if casts else df

    @staticmethod
    def _parse_partition_date(df: DataFrame) -> DataFrame:
        """
        PartitionDate derived from QT_INTERVAL_COUNT.
        Handles date or timestamp strings with several common formats: each
        pattern in _DATE_FORMATS is tried in order, then Spark's default parsing.
        """
        # NOTE(review): relies on to_timestamp yielding NULL for non-matching input;
        # with ANSI mode enabled it may raise instead - confirm the cluster setting.
        raw = col("QT_INTERVAL_COUNT")
        attempts = [to_timestamp(raw, fmt) for fmt in LoadRawTraffic._DATE_FORMATS]
        attempts.append(to_timestamp(raw))  # fallback
        return df.withColumn("PartitionDate", to_date(coalesce(*attempts)))

    @staticmethod
    def _normalize_region(df: DataFrame) -> DataFrame:
        """Trim and upper-case NM_REGION."""
        return df.withColumn("NM_REGION", F.upper(F.trim(col("NM_REGION"))))

    def _order_columns(self, df: DataFrame) -> DataFrame:
        """Project `df` onto the table's column order, skipping columns it lacks."""
        ordered = (
            ["NB_SCATS_SITE", "QT_INTERVAL_COUNT", "NB_DETECTOR"]
            + list(self._VOLUME_COLS)
            + ["NM_REGION", "CT_RECORDS", "QT_VOLUME_24HOUR", "CT_ALARM_24HOUR",
               "PartitionDate", "load_time", "source_file"]
        )
        # keep only those that exist (defensive)
        existing = [c for c in ordered if c in df.columns]
        return df.select(*existing)

    # --------------------- batch load ---------------------
    def batch_load(self, start_date: str = "2025-05-01", end_date: str = "2025-05-01") -> None:
        """
        Backfill daily CSVs in [start_date, end_date], inclusive.
        Expects files named VSDATA_YYYYMMDD.csv under self.landing_zone.

        A day whose file cannot be resolved at read time (AnalysisException, e.g.
        the path does not exist) is reported and skipped. Any failure while
        transforming or writing is reported and re-raised.

        Raises:
            ValueError: a date is not YYYY-MM-DD, or end_date precedes start_date.
        """
        print(f"📦 Batch load: {start_date} → {end_date} from {self.landing_zone}")
        try:
            start = datetime.strptime(start_date, "%Y-%m-%d").date()
            end = datetime.strptime(end_date, "%Y-%m-%d").date()
        except (TypeError, ValueError) as exc:
            raise ValueError("Dates must be in YYYY-MM-DD format") from exc

        if end < start:
            raise ValueError("end_date must be >= start_date")

        total_rows = 0
        for offset in range((end - start).days + 1):
            day = start + timedelta(days=offset)
            filename = f"VSDATA_{day.strftime('%Y%m%d')}.csv"
            path = f"{self.landing_zone}/{filename}"

            # Only the read is allowed to "skip": wrapping the write as well would
            # report a schema mismatch on append as a skipped file.
            try:
                raw = (self.spark.read
                         .option("header", True)
                         .option("inferSchema", False)
                         .schema(self.schema)
                         .csv(path))
            except AnalysisException as e:
                print(f"  ⚠️ {filename}: skipped (AnalysisException: {str(e).splitlines()[0]})")
                continue

            try:
                df = (raw.transform(self._drop_if_exists, ["_rescued_data"])
                         .transform(self._cast_volume_cols_int)
                         .transform(self._parse_partition_date)
                         .transform(self._normalize_region)
                         .withColumn("load_time", current_timestamp())
                         # Same file-metadata column stream_load uses, so both
                         # ingestion paths populate source_file the same way.
                         .withColumn("source_file", col("_metadata.file_path")))
                df = self._order_columns(df)

                (df.write.format("delta")
                    .mode("append")
                    .option("mergeSchema", "false")
                    .partitionBy(self._PARTITION_COL)
                    .saveAsTable(self.table_fqn))

                # NOTE: count() evaluates the CSV read a second time; it is kept
                # only for the per-file report below.
                count = df.count()
            except Exception as e:
                print(f"  ❌ {filename}: failed — {e}")
                raise

            total_rows += count
            print(f"  ✅ {filename}: {count} rows")
        print(f"Done. Total rows appended: {total_rows}")

    # --------------------- streaming load (Auto Loader) ---------------------
    def stream_load(self, file_pattern: str = "VSDATA_202506*.csv", trigger_once: bool = True, reset_checkpoint: bool = True) -> None:
        """
        Auto Loader over matching files, e.g. 'VSDATA_202506*.csv'.
        - trigger_once=True  -> one micro-batch then stop
        - trigger_once=False -> availableNow (ingest all available then stop)

        Args:
            file_pattern: glob, relative to self.landing_zone, of the files to ingest.
            trigger_once: selects the trigger as described above.
            reset_checkpoint: when True (the default) the streaming checkpoint,
                which also holds the Auto Loader schema directory, is deleted before
                the stream starts. NOTE(review): the sink is append-only, so this
                presumably re-appends files ingested by an earlier run; pass False
                for incremental runs.
        """
        print(f"🌊 Streaming load for pattern {file_pattern}")
        stream_path = f"{self.landing_zone}/{file_pattern}"
        stream_chk = f"{self.checkpoint_dir}/streaming"
        schema_loc = f"{stream_chk}/schema"

        if reset_checkpoint:
            print(f"🧹 Cleaning checkpoint: {stream_chk}")
            try:
                # Best effort: any failure here (including `dbutils` not being
                # defined in this session) is reported and the load continues.
                dbutils.fs.rm(stream_chk, recurse=True)
                print("  ✅ Checkpoint cleared.")
            except Exception as e:
                print(f"  ⚠️ Could not clear checkpoint ({e}). Continuing...")

        reader = (self.spark.readStream
                    .format("cloudFiles")
                    .option("cloudFiles.format", "csv")
                    .option("cloudFiles.inferColumnTypes", "true")
                    .option("cloudFiles.schemaLocation", schema_loc)
                    .option("cloudFiles.schemaEvolutionMode", "rescue")
                    # Pin the known columns to the types batch_load and the table
                    # DDL use, so inference cannot turn e.g. QT_INTERVAL_COUNT into
                    # a timestamp.
                    .option("cloudFiles.schemaHints", self._schema_hints())
                    .option("header", True)
                    .load(stream_path))

        stream_df = (reader
                     .transform(self._drop_if_exists, ["_rescued_data"])
                     .transform(self._cast_volume_cols_int)
                     .transform(self._parse_partition_date)
                     .transform(self._normalize_region)
                     .withColumn("load_time", current_timestamp())
                     .withColumn("source_file", col("_metadata.file_path")))
        stream_df = self._order_columns(stream_df)

        writer = (stream_df.writeStream
                    .format("delta")
                    .option("checkpointLocation", stream_chk)
                    .option("mergeSchema", "false")
                    .partitionBy(self._PARTITION_COL)
                    .outputMode("append"))

        trigger_kwargs = {"once": True} if trigger_once else {"availableNow": True}
        query = writer.trigger(**trigger_kwargs).toTable(self.table_fqn)

        query.awaitTermination()
        print("✅ Streaming load completed.")

    # --------------------- validation / maintenance ---------------------
    def validate_table(self) -> None:
        """Print the row count and first/last PartitionDate of the table; re-raise on failure."""
        print(f"🔎 Validating {self.table_fqn} ...")
        try:
            df = self.spark.table(self.table_fqn)
            total = df.count()
            parts = df.select("PartitionDate").distinct().orderBy("PartitionDate").collect()
            first = parts[0]["PartitionDate"] if parts else None
            last  = parts[-1]["PartitionDate"] if parts else None
            print(f"✅ {self.table_fqn}: {total} rows across {len(parts)} partitions "
                  f"(first={first}, last={last}).")
        except Exception as e:
            print(f"❌ Validation failed: {e}")
            raise

    def vacuum_optimize(self, zorder_cols=None) -> None:
        """
        Run OPTIMIZE on the table, Z-ordering by `zorder_cols`.

        Despite the name, no VACUUM is issued.

        Args:
            zorder_cols: columns to Z-order by; defaults to NB_SCATS_SITE and
                NB_DETECTOR. The partition column is dropped from the list because
                Delta rejects Z-ordering on partition columns. If nothing is left,
                a plain OPTIMIZE is run.
        """
        requested = zorder_cols or ["NB_SCATS_SITE", "NB_DETECTOR"]
        partition_col = self._PARTITION_COL.lower()
        usable = [c for c in requested if c.strip("`").lower() != partition_col]

        if not usable:
            self.spark.sql(f"OPTIMIZE {self.table_fqn}")
            print("Optimize done.")
            return

        self.spark.sql(f"OPTIMIZE {self.table_fqn} ZORDER BY ({', '.join(usable)})")
        print("Optimize + ZORDER done.")


# --------------------- example run (bronze) ---------------------
# End-to-end bronze run: build the loader from Config, make sure the target
# table exists, backfill one day, ingest a month of files through Auto Loader,
# then validate the table.
conf = Config()  # uses ENV/widgets from 1_config.py
bronze = LoadRawTraffic(catalog=conf.catalog, table_name=conf.bronze_table, env=conf.env)

# Creates the catalog/schema (if missing) and the partitioned Delta table.
bronze.create_db()
# Historical backfill (adjust dates as needed)
# NOTE(review): batch_load writes in append mode, so re-running the same date
# range presumably duplicates rows - confirm whether the table is reset between runs.
bronze.batch_load(start_date="2025-05-01", end_date="2025-05-01")
# Ingest a set of new drops via Auto Loader
# reset_checkpoint is left at its default (True), so the streaming checkpoint is
# cleared before the stream starts.
bronze.stream_load(file_pattern="VSDATA_202506*.csv", trigger_once=True)
# Prints row/partition statistics and re-raises on failure, so the final
# message is only reached when validation succeeded.
bronze.validate_table()
print("✅ bronze load/validate OK")