# 1. Setting up the environment and loading the parameters.

### a. Importing the libraries and configurations

In [0]:
# Importing required files & libraries
import time
import json
from datetime import datetime, timedelta, timezone
import logging
from pyspark.sql.types import *
from pyspark.sql.functions import *
from concurrent.futures import ThreadPoolExecutor, as_completed
import schemas  
from schemas import JsonSchemaValidator
import pandas as pd
from pyspark.sql.streaming import StreamingQuery
from typing import *

In [0]:
# Configure logging
# Log timestamps are rendered in Vietnam time (UTC+7).
vietnam_tz = timezone(timedelta(hours=7))


def _vietnam_time_converter(secs=None):
    """Convert an epoch timestamp (default: now) to a UTC+7 ``time.struct_time``."""
    if secs is None:
        secs = time.time()
    return datetime.fromtimestamp(secs, vietnam_tz).timetuple()


# Formatter.formatTime calls `self.converter(record.created)`. Converting that timestamp, rather
# than the current clock, makes each line carry the time its record was actually created.
# staticmethod keeps the function from being bound to the Formatter instance.
logging.Formatter.converter = staticmethod(_vietnam_time_converter)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

### b. Defining widgets and parameters & loading configs

In [0]:
# Declare the job parameters as text widgets (empty default value and empty label),
# then read back the value supplied for each one.
_job_parameters = ("environment", "job_name", "eventhub_namespace")
for _parameter in _job_parameters:
    dbutils.widgets.text(_parameter, "", "")
environment, job_name, eventhub_namespace = (
    dbutils.widgets.get(_param_name) for _param_name in _job_parameters
)

In [0]:
# Load the per-environment settings. The `with` block closes the file as soon as it is parsed
# (the previous bare `open` leaked the handle for the life of the notebook).
with open("../configs/config.json", encoding="utf-8") as config:
    settings = json.load(config)

In [0]:
# Every value below comes from the `environment` section of config.json. The Event Hub
# endpoint and the secret key name are further keyed by the `eventhub_namespace` widget value.
env_settings = settings[environment]
catalog_name = env_settings["catalogName"]
bronze_schema = env_settings["bronzeSchema"]
bootstrap_servers = env_settings[f"eventhubs_serverDetails_{eventhub_namespace}"]
scope_name = env_settings["scope_name"]
secret_key_name = env_settings[f"secret_key_name_{eventhub_namespace}"]
# The connection string itself is read from the secret scope, not from the config file.
eventhubs_connection_string = dbutils.secrets.get(scope=scope_name, key=secret_key_name)
eventhub_raw_path = env_settings["eventhub_raw_path"]
eventhub_raw_checkpoint = env_settings["eventhub_raw_checkpoint"]
eventhub_bronze_schema_location = env_settings["eventhub_bronze_schema_location"]
eventhub_bronze_checkpoint = env_settings["eventhub_bronze_checkpoint"]
eventhub_invalid_records_checkpoint = env_settings["eventhub_invalid_records_checkpoint"]
eventhub_invalid_records_path = env_settings["eventhub_invalid_records_path"]

In [0]:
# Log the resolved configuration for troubleshooting. The Event Hubs connection string is a
# credential fetched from a secret scope, so it is masked rather than written to the logs.
logger.info(
    "\n".join(
        [
            "Configuration Loaded:",
            f"  catalog_name                        : {catalog_name}",
            f"  bronze_schema                       : {bronze_schema}",
            f"  bootstrap_servers                   : {bootstrap_servers}",
            "  eventhubs_connection_string         : <redacted>",
            f"  eventhub_raw_path                   : {eventhub_raw_path}",
            f"  eventhub_raw_checkpoint             : {eventhub_raw_checkpoint}",
            f"  eventhub_bronze_schema_location     : {eventhub_bronze_schema_location}",
            f"  eventhub_bronze_checkpoint          : {eventhub_bronze_checkpoint}",
            f"  eventhub_invalid_records_checkpoint : {eventhub_invalid_records_checkpoint}",
            f"  eventhub_invalid_records_path       : {eventhub_invalid_records_path}",
        ]
    )
)

### c. Setting up the connection string for the Event Hub connection

In [0]:
# JAAS login config for Kafka SASL/PLAIN against Event Hubs: the literal user name
# "$ConnectionString" is paired with the namespace connection string as the password.
# The username/password options are whitespace-separated, as in the documented format.
eh_sasl = (
    "kafkashaded.org.apache.kafka.common.security.plain.PlainLoginModule required "
    f'username="$ConnectionString" password="{eventhubs_connection_string}";'
)

# 2. Reading data from Event Hubs and writing it to the bronze layer

In [0]:
# ---------- Define Output Schema for Validation Results ----------
# Struct produced per record by the validation UDF: a validity flag and the error text,
# both nullable.
validate_schema = StructType(
    [
        StructField(field_name, field_type, True)
        for field_name, field_type in (
            ("is_valid", BooleanType()),
            ("error", StringType()),
        )
    ]
)


# ---------- Pandas UDF for Schema Validation ----------
def create_validation_udf(validators_dict: Dict[str, Any]) -> Callable[[pd.Series, pd.Series], pd.DataFrame]:
    """
    Build a pandas UDF that validates JSON payloads against a per-hub validator.

    Args:
        validators_dict: Maps a hub name (the Kafka topic) to a validator object whose
            ``validate_instance(obj)`` raises when ``obj`` does not conform.

    Returns:
        A pandas UDF over (json string, hub name) columns. It returns a struct column with the
        fields of ``validate_schema``: ``is_valid`` and ``error`` (null when the record is valid).

    Raises:
        Re-raises, after logging, any error encountered while defining the UDF.
    """
    try:
        @pandas_udf(validate_schema)
        def validate_all(json_series: pd.Series, hub_series: pd.Series) -> pd.DataFrame:
            is_valid_list, error_list = [], []
            for json_str, hub in zip(json_series, hub_series):
                validator = validators_dict.get(hub)
                if validator is None:
                    # Explicit message: a bare KeyError would stringify to just the quoted hub name.
                    is_valid_list.append(False)
                    error_list.append(f"No schema validator registered for hub '{hub}'")
                    continue
                try:
                    # Any failure (e.g. unparseable JSON or a schema violation) marks the record invalid.
                    validator.validate_instance(json.loads(json_str))
                    is_valid_list.append(True)
                    error_list.append(None)
                except Exception as e:
                    is_valid_list.append(False)
                    error_list.append(str(e))
            return pd.DataFrame({"is_valid": is_valid_list, "error": error_list})

        return validate_all
    except Exception as e:
        logger.error(f"Failed to create validation UDF: {e}")
        raise


# ---------- Functions ----------


def get_eventhub_list(catalog_name: str, eventhub_namespace: str) -> List[str]:
    """
    Fetch the distinct raw table names (Event Hubs) registered for a namespace.

    Reads ``{catalog_name}.default.lookup_table_source2raw`` and keeps rows with
    ``source_name = 'EVENTHUB'`` and the given namespace. The namespace is compared as a
    column value through the DataFrame API rather than spliced into SQL text, so a widget value
    containing quotes cannot change the query.

    Args:
        catalog_name: Catalog that holds the ``default.lookup_table_source2raw`` table.
        eventhub_namespace: Namespace whose hubs should be returned.

    Returns:
        List of ``table_name`` values (no duplicates, unordered).

    Raises:
        Re-raises, after logging, any error from reading the lookup table.
    """
    try:
        df = (
            spark.table(f"{catalog_name}.default.lookup_table_source2raw")
            .filter(
                (col("source_name") == "EVENTHUB")
                & (col("eventhub_namespace") == eventhub_namespace)
            )
            .select("table_name")
            .distinct()
        )
        event_hubs_list = [row["table_name"] for row in df.collect()]
        logger.info(
            f"Fetched Event Hub tables for eventhub namespace '{eventhub_namespace}': {event_hubs_list}"
        )
        return event_hubs_list
    except Exception as e:
        logger.error(f"Failed to fetch Event Hub list for namespace {eventhub_namespace}: {e}")
        raise


def get_validators(event_hubs_list: List[str]) -> Dict[str, Any]:
    """
    Look up, in the ``schemas`` module, the validator object named after each hub.

    Hubs with no same-named attribute in ``schemas`` are left out of the result, so missing
    schemas must be detected by the caller.

    Args:
        event_hubs_list: Hub (raw table) names.

    Returns:
        Mapping of hub name to the ``schemas`` attribute with that name.

    Raises:
        Re-raises, after logging, any error raised while reading ``schemas`` attributes.
    """
    try:
        validators = {
            hub: getattr(schemas, hub)
            for hub in event_hubs_list
            if hasattr(schemas, hub)  # Only include if the schema is defined
        }
    except Exception as e:
        logger.error(f"Failed to build validators from schemas module: {str(e)}")
        raise
    logger.info(f"Validators loaded for hubs: {list(validators.keys())}")
    return validators


def read_eventhub_stream(hubs: str) -> DataFrame:
    """
    Open a Kafka streaming read against Event Hubs for the given topic(s).

    Args:
        hubs: Comma-separated Event Hub (Kafka topic) names to subscribe to.

    Returns:
        The streaming DataFrame produced by the Kafka source.

    Raises:
        Re-raises, after logging, any error raised while configuring the reader.
    """
    try:
        reader = spark.readStream.format("kafka")
        # SASL_SSL + PLAIN authentication using the JAAS config in `eh_sasl`. "earliest" only
        # affects a query with no checkpoint yet; data loss is tolerated rather than failing.
        kafka_options = {
            "subscribe": hubs,
            "kafka.bootstrap.servers": bootstrap_servers,
            "startingOffsets": "earliest",
            "kafka.sasl.mechanism": "PLAIN",
            "kafka.security.protocol": "SASL_SSL",
            "kafka.sasl.jaas.config": eh_sasl,
            "failOnDataLoss": "false",
            "kafka.request.timeout.ms": "60000",
            "kafka.session.timeout.ms": "30000",
        }
        for option_key, option_value in kafka_options.items():
            reader = reader.option(option_key, option_value)
        return reader.load()
    except Exception as e:
        logger.error(f"Failed to start Event Hub stream for hubs '{hubs}': {e}")
        raise


def decode_stream_data(df: DataFrame) -> DataFrame:
    """
    Reduce Kafka records to their topic name and UTF-8 decoded payload.

    Args:
        df: Streaming DataFrame from the Kafka source (with ``topic`` and ``value`` columns).

    Returns:
        DataFrame with ``hub`` (the topic) and ``json_str`` (the decoded value) columns.

    Raises:
        Re-raises, after logging, any error raised while building the projection.
    """
    try:
        hub_column = col("topic").alias("hub")
        payload_column = decode(col("value"), "UTF-8").cast("string").alias("json_str")
        return df.select(hub_column, payload_column)
    except Exception as e:
        logger.error(f"Failed to decode Kafka stream data: {e}")
        raise


def write_valid_invalid_streams(
    hub: str,
    decoded_df: DataFrame,
    validate_all: Callable[[pd.Series, pd.Series], pd.DataFrame],
    stream_queries: List[Tuple[str, StreamingQuery]]
) -> None:
    """
    Validate one hub's records and start its raw (valid) and invalid-record sinks.

    Valid records are written as text (the JSON string only) to ``{eventhub_raw_path}/{hub}``.
    Invalid records keep ``json_str`` and ``error``, gain an ``ingestion_time``
    (Asia/Ho_Chi_Minh), and are written as Delta to ``{eventhub_invalid_records_path}/{hub}``.
    Both sinks append with a 15-minute processing-time trigger and per-hub checkpoints.

    Args:
        hub: Hub (Kafka topic) whose records are processed.
        decoded_df: Streaming DataFrame with ``hub`` and ``json_str`` columns.
        validate_all: Validation UDF returning a struct with ``is_valid`` and ``error``.
        stream_queries: Accumulator; both started queries are appended as ``(name, query)``.

    Raises:
        Re-raises, after logging, any error from starting either sink. If the invalid sink
        fails to start, the already-started raw query is stopped first so it is not orphaned.
    """
    # Restrict to this hub first so the validation UDF only runs on records that are kept.
    validated_df = (
        decoded_df.filter(col("hub") == hub)
        .withColumn("validation", validate_all(col("json_str"), col("hub")))
        .withColumn("is_valid", col("validation.is_valid"))
        .withColumn("error", col("validation.error"))
        .drop("validation")
    )

    # Valid records: only the JSON string remains, since the text sink writes a single column.
    valid_df = validated_df.filter(col("is_valid")).drop("is_valid", "error", "hub")

    # Invalid records keep the payload and error, stamped with the ingestion time.
    invalid_df = (
        validated_df.filter(~col("is_valid"))
        .withColumn(
            "ingestion_time",
            from_utc_timestamp(current_timestamp(), "Asia/Ho_Chi_Minh"),
        )
        .drop("is_valid", "hub")
    )

    # NOTE: each sink below is a separate streaming query, so the Kafka source is consumed
    # (and the UDF evaluated) once per sink.

    # Write valid records to raw path (as text)
    try:
        write_valid = (
            valid_df.writeStream.format("text")
            .option("checkpointLocation", f"{eventhub_raw_checkpoint}/{hub}")
            .option("path", f"{eventhub_raw_path}/{hub}")
            .outputMode("append")
            .trigger(processingTime="15 minutes")
            .start()
        )
    except Exception as e:
        logger.error(f"Failed to start write stream for valid records of {hub}: {e}")
        raise

    # Write invalid records to Delta
    try:
        write_invalid = (
            invalid_df.writeStream.format("delta")
            .option("checkpointLocation", f"{eventhub_invalid_records_checkpoint}/{hub}")
            .option("path", f"{eventhub_invalid_records_path}/{hub}")
            .outputMode("append")
            .trigger(processingTime="15 minutes")
            .start()
        )
    except Exception as e:
        logger.error(f"Failed to start write stream for invalid records of {hub}: {e}")
        # Don't leave the raw query running untracked when this hub's setup fails.
        try:
            write_valid.stop()
        except Exception as stop_error:
            logger.error(f"Failed to stop raw stream of {hub}: {stop_error}")
        raise

    # Store queries for later monitoring
    stream_queries.extend(
        [(f"{hub}_raw_stream", write_valid), (f"{hub}_invalid_stream", write_invalid)]
    )


def start_bronze_stream_for_hub(
    hub: str,
    catalog_name: str,
    bronze_schema: str,
    stream_queries: List[Tuple[str, StreamingQuery]]
) -> None:
    """
    Stream the hub's raw JSON files into its Bronze Delta table and audit row counts per batch.

    Auto Loader (``cloudFiles``, JSON, inferred column types, schema tracked under
    ``{eventhub_bronze_schema_location}/{hub}``) reads ``{eventhub_raw_path}/{hub}``. Each
    15-minute micro-batch is appended to ``{catalog_name}.{bronze_schema}.{hub}``, partitioned by
    ``ProcDate``. Its row counts per ``ProcDate`` day are appended to
    ``{catalog_name}.default.eventhub_audit_log``. The started query is appended to
    ``stream_queries``, then the hub's ``_bronze_started`` marker is written if missing.

    ``ProcDate`` handling: if the inferred schema has no ``ProcDate``, the current
    Asia/Ho_Chi_Minh time (yyyyMMddHHmmss) is used. Otherwise, 14-digit values are kept and other
    values are reparsed from ``yyyy-MM-dd HH:mm:ss``. Values matching neither format are
    presumably left null (depends on Spark's parser settings).

    Raises:
        Re-raises, after logging, any error from starting the stream or writing the marker.
    """
    df = (
        spark.readStream.format("cloudFiles")
        .option("cloudFiles.format", "json")
        .option("cloudFiles.inferColumnTypes", "true")
        .option("cloudFiles.schemaLocation", f"{eventhub_bronze_schema_location}/{hub}")
        .option("mergeSchema", "true")
        .load(f"{eventhub_raw_path}/{hub}")
    )

    # Normalize or add ProcDate (a yyyyMMddHHmmss string, used as the partition column)
    if "ProcDate" not in df.columns:
        df = df.withColumn(
            "ProcDate",
            date_format(
                from_utc_timestamp(current_timestamp(), "Asia/Ho_Chi_Minh"),
                "yyyyMMddHHmmss",
            ),
        )
    else:
        df = df.withColumn(
            "ProcDate",
            when(col("ProcDate").rlike("^[0-9]{14}$"), col("ProcDate")).otherwise(
                date_format(
                    to_timestamp(col("ProcDate"), "yyyy-MM-dd HH:mm:ss"),
                    "yyyyMMddHHmmss",
                )
            ),
        )

    # Lineage column: "<source file path>|<file modification time in Asia/Ho_Chi_Minh>"
    df = df.withColumn(
        "source_metadata",
        concat_ws(
            "|",
            col("_metadata.file_path"),
            from_utc_timestamp(col("_metadata.file_modification_time"), "Asia/Ho_Chi_Minh").cast("string"),
        ),
    )

    target_table = f"{catalog_name}.{bronze_schema}.{hub}"
    audit_table = f"{catalog_name}.default.eventhub_audit_log"

    def write_batch_with_rowcount_audit(batch_df, batch_id):
        """Append one non-empty micro-batch to the Bronze table, then log its row counts per day."""
        # The batch is scanned up to three times (emptiness check, table append, audit
        # aggregation); caching it avoids re-reading the source files for each pass.
        batch_df.persist()
        try:
            if batch_df.count() == 0:
                return
            (
                batch_df.write.format("delta")
                .mode("append")
                .partitionBy("ProcDate")
                .option("mergeSchema", "true")
                .saveAsTable(target_table)
            )
            audit_df = (
                batch_df
                .withColumn("proc_date", substring("ProcDate", 1, 8))  # yyyyMMdd prefix of yyyyMMddHHmmss
                .groupBy("proc_date")
                .agg(count("*").alias("row_count"))
                .withColumn("hub", lit(hub))
                .withColumn("batch_id", lit(batch_id))
                .withColumn(
                    "batch_timestamp",
                    from_utc_timestamp(current_timestamp(), "Asia/Ho_Chi_Minh"),
                )
            )
            audit_df.write.format("delta").mode("append").saveAsTable(audit_table)
        except Exception as e:
            logger.error(f"Failed to write bronze batch {batch_id} or its audit log for {hub}: {e}")
            raise
        finally:
            batch_df.unpersist()

    # Write to Bronze Delta table
    try:
        write_bronze = (
            df.writeStream
            .foreachBatch(write_batch_with_rowcount_audit)
            .option("checkpointLocation", f"{eventhub_bronze_checkpoint}/{hub}")
            .trigger(processingTime="15 minutes")
            .start()
        )
    except Exception as e:
        logger.error(f"Failed to start write bronze stream of {hub}: {e}")
        raise
    logger.info(f"Bronze stream started for {hub}")
    # Track the query before writing the marker, so a marker failure cannot leave it running
    # but unmonitored.
    stream_queries.append((f"{hub}_bronze_stream", write_bronze))
    write_bronze_marker_if_missing(hub)


def write_bronze_marker_if_missing(hub: str) -> None:
    """
    Create the empty ``_bronze_started`` marker file in the hub's raw directory, unless present.

    The marker records that the Bronze stream has been started for this hub.

    Raises:
        Re-raises, after logging, any error from listing the directory or writing the file.
    """
    marker_path = f"{eventhub_raw_path}/{hub}/_bronze_started"
    try:
        existing_names = {entry.name for entry in dbutils.fs.ls(f"{eventhub_raw_path}/{hub}")}
        if "_bronze_started" in existing_names:
            return
        dbutils.fs.put(marker_path, "", overwrite=False)
        logger.info(f"Wrote bronze started marker for hub: {hub}")
    except Exception as e:
        logger.error(f"Failed to write bronze marker for {hub}: {e}")
        raise


def is_data_available(path: str) -> bool:
    """
    Check whether a hub's raw directory is ready for its Bronze stream.

    Args:
        path: The hub's raw directory.

    Returns:
        True if the directory holds the ``_bronze_started`` marker (Bronze was already started),
        or at least one entry with size > 0 whose name does not start with ``_`` or ``.``.
        False otherwise, including when the directory cannot be listed (the error is logged).
    """
    try:
        entries = dbutils.fs.ls(path)
    except Exception as e:
        logger.error(f"Failed to check files at path {path}: {e}")
        return False
    # A single listing answers both questions.
    if any(entry.name == "_bronze_started" for entry in entries):
        return True
    return any(
        entry.size > 0 and not entry.name.startswith(("_", "."))
        for entry in entries
    )


def check_and_start_stream(hub: str, poll_interval_seconds: float = 60) -> None:
    """
    Wait until raw data is available for a hub, then start its Bronze stream.

    Polls ``{eventhub_raw_path}/{hub}`` with ``is_data_available`` every
    ``poll_interval_seconds`` seconds, with no upper bound on the wait. It then calls
    ``start_bronze_stream_for_hub``, passing the module-level ``stream_queries`` list.

    Args:
        hub: Hub whose Bronze stream should be started.
        poll_interval_seconds: Delay between availability checks (default 60).

    Raises:
        Re-raises, after logging, any error from starting the Bronze stream, so a caller that
        collects this function's result sees the failure instead of it being dropped.
    """
    path = f"{eventhub_raw_path}/{hub}"
    while not is_data_available(path):
        logger.info(f"Waiting for data for hub: {hub}. Retrying in {poll_interval_seconds}s...")
        time.sleep(poll_interval_seconds)

    logger.info(f"Data found for hub: {hub} — starting bronze stream")
    try:
        start_bronze_stream_for_hub(hub, catalog_name, bronze_schema, stream_queries)
    except Exception as e:
        logger.error(f"Error starting bronze stream for {hub}: {e}")
        raise


def monitor_streams(
    queries: List[Tuple[str, StreamingQuery]]
) -> None:
    """
    Block on all streaming queries at once and fail fast if any of them fails.

    Each query's ``awaitTermination()`` runs in its own worker thread. Results are consumed in
    completion order, so the first failure surfaces immediately, whatever its position in
    ``queries``. On failure, every still-active query is stopped, which lets the remaining
    ``awaitTermination()`` calls return so the thread pool can shut down.

    Args:
        queries: ``(name, query)`` pairs; an empty list returns immediately.

    Raises:
        RuntimeError: if any query terminates with an exception (chained to the original).
    """
    if not queries:
        return

    def wait_for_query(name, query):
        try:
            query.awaitTermination()
        except Exception as e:
            logger.error(f"Streaming query '{name}' failed: {e}")
            raise RuntimeError(f"Query '{name}' terminated with error.") from e

    # One thread per query: with a smaller pool, queries beyond the pool size would not be
    # awaited until an earlier query terminated, hiding their failures.
    with ThreadPoolExecutor(max_workers=len(queries)) as executor:
        futures = [executor.submit(wait_for_query, name, query) for name, query in queries]
        try:
            for future in as_completed(futures):
                future.result()
        except Exception:
            # Without stopping the others, leaving the `with` block would wait forever on
            # their awaitTermination() calls.
            for name, query in queries:
                try:
                    if query.isActive:
                        query.stop()
                except Exception as stop_error:
                    logger.error(f"Failed to stop streaming query '{name}': {stop_error}")
            raise

In [0]:
# ---------- Main Function ----------
def main():
    """
    Run the Event Hub -> raw -> Bronze pipeline for every hub in the configured namespace.

    Steps:
        1. Load the namespace's hub list from the lookup table and a schema validator per hub.
           Fail fast if no hubs are found or if any hub lacks a validator.
        2. Start, per hub, a Kafka read split into the raw (valid) and invalid-record sinks.
        3. In parallel, wait for each hub's raw data and start its Bronze stream.
        4. Block while monitoring every started streaming query.

    Raises:
        ValueError: if no hubs are registered or some hub has no schema validator.
        Any exception raised while starting a stream or reported by a monitored query.
    """
    # Module-level so that check_and_start_stream (run in worker threads) can append to it.
    global stream_queries
    stream_queries = []  # To track all active streaming queries for monitoring

    # Step 1: Get Event Hub metadata and prepare validators
    # Retrieve list of Event Hub table names from the lookup table
    event_hubs_list = get_eventhub_list(catalog_name, eventhub_namespace)

    # Checking if no Event Hubs are found for the given namespace
    if not event_hubs_list:
        logger.error(f"No Event Hub tables found for namespace '{eventhub_namespace}'. Exiting.")
        raise ValueError(f"No Event Hub tables found for namespace '{eventhub_namespace}'.")

    # Dynamically fetch JSON schema validators for each hub
    validators = get_validators(event_hubs_list)

    # Check for hubs without schemas
    hubs_without_schema = [hub for hub in event_hubs_list if hub not in validators]
    if hubs_without_schema:
        logger.error(f"The following hubs do not have schemas defined: {hubs_without_schema}")
        raise ValueError(f"No schema defined for hubs: {hubs_without_schema}")

    # Create the validation UDF using the validator dictionary
    validate_all = create_validation_udf(validators)

    # Step 2: Read Kafka stream and decode, validate and write valid and invalid data streams per hub
    for hub in event_hubs_list:
        logger.info(f"Initialized read stream for the hub: {hub}")
        raw_eventhub_df = read_eventhub_stream(hub)
        decoded_df = decode_stream_data(raw_eventhub_df)
        logger.info(f"Initialized write stream for the hub: {hub}")
        write_valid_invalid_streams(hub, decoded_df, validate_all, stream_queries)

    # Step 3: trigger bronze streams in parallel when data is found.
    # One worker per hub: each worker blocks until its hub has data, so a fixed-size pool would
    # keep later hubs from even being checked while earlier ones are still waiting.
    logger.info("Waiting for raw data to trigger bronze ingestion per hub...")
    with ThreadPoolExecutor(max_workers=len(event_hubs_list)) as executor:
        # Iterating the results re-raises any exception from a worker; an unconsumed map()
        # result would drop it silently.
        for _ in executor.map(check_and_start_stream, event_hubs_list):
            pass
    logger.info("All bronze streams triggered")

    # Step 4: Monitor all stream queries continuously
    monitor_streams(stream_queries)


# ---------- Entry Point ----------
if __name__ == "__main__":
    # Run the pipeline; any failure is logged, then re-raised unchanged.
    try:
        main()
    except Exception as pipeline_error:
        logger.error(f"Pipeline failed with unexpected error: {pipeline_error}")
        raise