In [None]:
import re
import traceback
from datetime import datetime, timedelta, timezone

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# ============================================================
# CONFIGURATION
# ============================================================

# Get parameters
dbutils.widgets.text("catalog_name", "cloudfastener")
dbutils.widgets.text("company_id", "")
# Optional YYYY-MM-DD (UTC) to re-run / backfill a specific day; empty = today.
dbutils.widgets.text("job_date", "")

catalog_name = dbutils.widgets.get("catalog_name").strip()
company_id_param = dbutils.widgets.get("company_id").strip()
job_date_param = dbutils.widgets.get("job_date").strip()

if not catalog_name:
    raise ValueError("Missing required param: catalog_name")

# Set timezone (string timestamps below are parsed in this session time zone)
spark.conf.set("spark.sql.session.timeZone", "UTC")

# Calculate job date and processing window ONCE, on the driver.
# job_date = processing day at 00:00:00 UTC
# window   = [job_date - 1 day, job_date)
# A Spark expression such as F.current_timestamp() would be re-evaluated by every
# query that references it, so a run crossing midnight UTC could read one window,
# stamp silver with another date and aggregate gold for a third. Pinning the value
# as a literal keeps every stage of the run on the same day.
if job_date_param:
    try:
        job_date_dt = datetime.strptime(job_date_param, "%Y-%m-%d").replace(tzinfo=timezone.utc)
    except ValueError as exc:
        raise ValueError(f"Invalid job_date: {job_date_param!r}. Expected YYYY-MM-DD.") from exc
else:
    job_date_dt = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
window_start_dt = job_date_dt - timedelta(days=1)

_TS_FORMAT = "%Y-%m-%d %H:%M:%S"
job_date_str = job_date_dt.strftime(_TS_FORMAT)
window_start_str = window_start_dt.strftime(_TS_FORMAT)

# Spark-side timestamp literals used by the downstream queries.
job_date = F.to_timestamp(F.lit(job_date_str))
window_end_ts = job_date
window_start_ts = F.to_timestamp(F.lit(window_start_str))

# cf_processed_time for data records (same as job_date)
cf_processed_time = job_date

print("=" * 60)
print("BRONZE → SILVER → GOLD ETL Pipeline")
print("=" * 60)
print(f"catalog_name     = {catalog_name}")
print(f"job_date         = {job_date_str} UTC")
print(f"window_start     = {window_start_str} UTC")
print(f"window_end       = {job_date_str} UTC (exclusive)")
print("=" * 60)

## Helper Functions

In [None]:
def is_valid_company_id(schema_name: str) -> bool:
    """Check if schema name matches company ID format: 12 chars, lowercase alphanumeric.

    Only ASCII ``a-z`` and ``0-9`` are accepted. ``str.isalnum``/``str.islower``
    are deliberately not used: they accept non-ASCII letters (e.g. full-width or
    accented characters), and ``islower()`` is False for an all-digit string.
    The ID is interpolated into SQL table names, so the check is kept strict.

    Args:
        schema_name: Candidate schema name / company ID.

    Returns:
        True when ``schema_name`` is exactly 12 characters drawn from ``[a-z0-9]``.
    """
    return re.fullmatch(r"[a-z0-9]{12}", schema_name) is not None

def discover_companies(catalog: str) -> list:
    """Discover all company schemas in ``catalog``.

    Schemas are listed explicitly with ``SHOW SCHEMAS IN <catalog>``. The function
    does not call ``spark.catalog.listDatabases()``, which only covers the
    session's current catalog. That could return schemas from a different catalog
    than the one the per-company table names are later built against.

    Args:
        catalog: Catalog whose schemas are scanned.

    Returns:
        Sorted list of schema names that pass ``is_valid_company_id``. If listing
        fails, the error is printed and an empty list is returned.
    """
    try:
        # Backtick-quote the identifier (doubling any embedded backtick).
        quoted_catalog = "`" + catalog.replace("`", "``") + "`"
        rows = spark.sql(f"SHOW SCHEMAS IN {quoted_catalog}").collect()
        companies = []
        for row in rows:
            # The schema name is the first column; its header differs between
            # runtimes, so it is read positionally.
            parts = str(row[0]).split('.')
            # Accept both "cloudfastener.xs22xw4aw73q" and "xs22xw4aw73q".
            if len(parts) >= 2 and parts[0] == catalog:
                schema_name = parts[1]
            elif len(parts) == 1:
                schema_name = parts[0]
            else:
                continue

            if is_valid_company_id(schema_name):
                companies.append(schema_name)
        return sorted(companies)
    except Exception as e:
        print(f"Error discovering companies: {e}")
        return []

def table_exists(full_name: str) -> bool:
    """Report whether ``full_name`` resolves to a table.

    Any error raised by the catalog lookup is treated as "table not present".
    """
    try:
        found = spark.catalog.tableExists(full_name)
    except Exception:
        found = False
    return found

def normalize_finding_id(col):
    """Trim a finding-ID column; blank values become NULL and NULL stays NULL."""
    trimmed = F.trim(col)
    # `when` with no `otherwise` yields NULL for every row that fails the test.
    return F.when(F.length(trimmed) > 0, trimmed)

def parse_iso8601_to_ts(col):
    """Convert an ISO-8601 timestamp column to a Spark timestamp.

    No explicit pattern is passed, so Spark's default string-to-timestamp rules
    apply.
    """
    parsed = F.to_timestamp(col, format=None)
    return parsed

## Discover Companies to Process

In [None]:
# Determine companies to process. An empty widget value or "ALL" triggers
# auto-discovery; anything else must be one well-formed company ID.
run_all_companies = company_id_param.upper() in ("", "ALL")

if run_all_companies:
    companies_to_process = discover_companies(catalog_name)
    print(f"Auto-discovery mode: Found {len(companies_to_process)} companies")
    if len(companies_to_process) > 0:
        print(f"Companies: {', '.join(companies_to_process)}")
elif is_valid_company_id(company_id_param):
    companies_to_process = [company_id_param]
    print(f"Single company mode: {company_id_param}")
else:
    raise ValueError(f"Invalid company_id format: {company_id_param}. Must be 12 lowercase alphanumeric characters.")

if len(companies_to_process) == 0:
    raise ValueError("No companies to process. Check catalog and schema names.")

print(f"\nTotal companies to process: {len(companies_to_process)}")
print("=" * 60)

## Process Each Company

Loop through each company and run the bronze → silver → gold pipeline.

In [None]:
# Track results
successful_companies = []
failed_companies = []
skipped_companies = []

# Severity labels ordered from least to most severe. Silver severities are
# lower-cased before aggregation. Any label not listed here ranks 0, including
# the "unclassified" placeholder used for NULL.
SEVERITY_RANK = {
    "informational": 1,
    "low": 2,
    "medium": 3,
    "high": 4,
    "critical": 5,
}


def load_bronze_window(table_name: str, product_col: str):
    """Load Security Hub rows from a bronze table for the processing window.

    Args:
        table_name: Fully qualified bronze table name.
        product_col: Column (possibly a nested path) holding the product name.

    Returns:
        A lazy DataFrame of rows whose ``product_col`` equals "Security Hub" and
        whose bronze ``cf_processed_time`` lies in [window_start_ts, window_end_ts).
    """
    return (
        spark.table(table_name)
        .where(
            (F.col(product_col) == "Security Hub") &
            (F.col("cf_processed_time") >= window_start_ts) &
            (F.col("cf_processed_time") < window_end_ts)
        )
    )


def transform_asff(df):
    """Project ASFF bronze rows onto the canonical silver columns.

    ``_bronze_processed_time`` and ``_preference`` (1 = wins over OCSF on ties)
    are helper columns used by the dedup step and dropped before the MERGE.
    """
    return (
        df.select(
            normalize_finding_id(F.col("finding_id")).alias("finding_id"),
            parse_iso8601_to_ts(F.col("updated_at")).alias("finding_modified_time"),
            F.to_timestamp(F.lit(cf_processed_time)).alias("cf_processed_time"),
            F.when(F.upper(F.col("workflow.Status")) == "NEW", "New")
             .when(F.upper(F.col("workflow.Status")) == "NOTIFIED", "In Progress")
             .when(F.upper(F.col("workflow.Status")) == "SUPPRESSED", "Suppressed")
             .when(F.upper(F.col("workflow.Status")) == "RESOLVED", "Resolved")
             .otherwise(F.col("workflow.Status"))
             .alias("finding_status"),
            F.col("aws_account_id").cast("string").alias("account_id"),
            F.col("finding_region").cast("string").alias("region_id"),
            F.expr("compliance.AssociatedStandards[0].StandardsId").cast("string").alias("standard_id"),
            F.col("compliance.SecurityControlId").cast("string").alias("control_id"),
            F.col("compliance.Status").cast("string").alias("compliance_status"),
            F.col("severity.Label").cast("string").alias("severity"),
            F.col("cf_processed_time").alias("_bronze_processed_time"),
            F.lit(1).alias("_preference")
        )
    )


def transform_ocsf(df):
    """Project OCSF (Security Lake) bronze rows onto the canonical silver columns.

    ``_preference`` is 0, so ASFF rows win when all other dedup keys tie.
    """
    return (
        df.select(
            normalize_finding_id(F.col("finding_info.uid")).alias("finding_id"),
            parse_iso8601_to_ts(F.col("finding_info.modified_time_dt")).alias("finding_modified_time"),
            F.to_timestamp(F.lit(cf_processed_time)).alias("cf_processed_time"),
            F.col("status").cast("string").alias("finding_status"),
            F.col("cloud.account.uid").cast("string").alias("account_id"),
            F.col("cloud.region").cast("string").alias("region_id"),
            F.expr("compliance.standards[0]").cast("string").alias("standard_id"),
            F.col("compliance.control").cast("string").alias("control_id"),
            F.col("compliance.status").cast("string").alias("compliance_status"),
            F.col("severity").cast("string").alias("severity"),
            F.col("cf_processed_time").alias("_bronze_processed_time"),
            F.lit(0).alias("_preference")
        )
    )


def severity_rank(col):
    """Return an integer rank for a lower-cased severity column (0 for unlisted labels)."""
    rank_expr = None
    for label, rank in SEVERITY_RANK.items():
        condition = col == label
        rank_expr = F.when(condition, rank) if rank_expr is None else rank_expr.when(condition, rank)
    return rank_expr.otherwise(0)


for company_id in companies_to_process:
    print("\n" + "=" * 60)
    print(f"Processing company: {company_id}")
    print("=" * 60)

    try:
        # Define table names for this company
        asff_tbl = f"{catalog_name}.{company_id}.aws_securityhub_findings_1_0"
        ocsf_tbl = f"{catalog_name}.{company_id}.aws_securitylake_sh_findings_2_0"
        silver_tbl = f"{catalog_name}.{company_id}.aws_compliance_findings"
        gold_tbl = f"{catalog_name}.{company_id}.aws_standard_summary"

        print(f"ASFF bronze      = {asff_tbl}")
        print(f"OCSF bronze      = {ocsf_tbl}")
        print(f"Silver           = {silver_tbl}")
        print(f"Gold             = {gold_tbl}")
        print("-" * 60)

        # ============================================================
        # BRONZE → SILVER: Load and Transform
        # ============================================================

        # Check table existence
        asff_exists = table_exists(asff_tbl)
        ocsf_exists = table_exists(ocsf_tbl)

        print(f"ASFF exists? {asff_exists}")
        print(f"OCSF exists? {ocsf_exists}")

        if not asff_exists and not ocsf_exists:
            print(f"⚠️  Neither bronze table exists for {company_id}. Skipping...")
            skipped_companies.append((company_id, "No bronze tables"))
            continue

        # (source label, table exists?, table name, product-name column)
        bronze_inputs = [
            ("ASFF", asff_exists, asff_tbl, "product_name"),
            ("OCSF", ocsf_exists, ocsf_tbl, "metadata.product.name"),
        ]
        sources = []
        for src, exists, tbl, product_col in bronze_inputs:
            if not exists:
                continue
            df_raw = load_bronze_window(tbl, product_col)
            row_count = df_raw.count()
            print(f"{src} rows in window: {row_count}")
            if row_count > 0:
                sources.append((src, df_raw))

        if len(sources) == 0:
            print(f"⚠️  No rows found in window for {company_id}. Skipping...")
            skipped_companies.append((company_id, "No data in window"))
            continue

        # Transform each source to the canonical schema. The transforms already
        # normalise finding_id, so only rows without an ID need dropping here.
        transformers = {"ASFF": transform_asff, "OCSF": transform_ocsf}
        canonical_dfs = [
            transformers[src](df_raw).where(F.col("finding_id").isNotNull())
            for src, df_raw in sources
        ]

        df_union = canonical_dfs[0]
        for d in canonical_dfs[1:]:
            df_union = df_union.unionByName(d, allowMissingColumns=True)

        union_count = df_union.count()
        print(f"Union rows: {union_count}")

        # Every bronze row may have lacked a usable finding ID.
        if union_count == 0:
            print(f"⚠️  No valid findings after filtering for {company_id}. Skipping...")
            skipped_companies.append((company_id, "No valid findings"))
            continue

        # Deduplicate: newest modification wins, then ASFF over OCSF, then the
        # most recently processed bronze row.
        w = Window.partitionBy("finding_id").orderBy(
            F.col("finding_modified_time").desc_nulls_last(),
            F.col("_preference").desc(),
            F.col("_bronze_processed_time").desc_nulls_last()
        )

        df_winners = (
            df_union
            .withColumn("_rn", F.row_number().over(w))
            .where(F.col("_rn") == 1)
        )

        winner_count = df_winners.count()
        print(f"Winners after dedup: {winner_count} rows")

        # Prepare for merge (drop internal columns including _preference)
        df_stage = (
            df_winners
            .drop("_rn", "_preference", "_bronze_processed_time")
        )

        df_stage.createOrReplaceTempView("stg_silver")

        # MERGE into silver (no source_preference column).
        # NOTE(review): matched rows are only updated when the incoming
        # modification time is newer (or the stored one is NULL). Unchanged
        # findings therefore keep an older cf_processed_time and are excluded
        # from today's gold aggregation below. Confirm this is intended.
        spark.sql(f"""
            MERGE INTO {silver_tbl} t
            USING stg_silver s
               ON t.finding_id = s.finding_id
            WHEN MATCHED AND (
                   t.finding_modified_time IS NULL
                OR s.finding_modified_time > t.finding_modified_time
            ) THEN UPDATE SET
                t.cf_processed_time     = s.cf_processed_time,
                t.finding_modified_time = s.finding_modified_time,
                t.finding_status        = s.finding_status,
                t.account_id            = s.account_id,
                t.region_id             = s.region_id,
                t.standard_id           = s.standard_id,
                t.control_id            = s.control_id,
                t.compliance_status     = s.compliance_status,
                t.severity              = s.severity
            WHEN NOT MATCHED THEN INSERT (
                finding_id, cf_processed_time, finding_modified_time,
                finding_status, account_id, region_id, standard_id,
                control_id, compliance_status, severity
            ) VALUES (
                s.finding_id, s.cf_processed_time, s.finding_modified_time,
                s.finding_status, s.account_id, s.region_id, s.standard_id,
                s.control_id, s.compliance_status, s.severity
            )
        """)

        print("✓ Silver MERGE completed")

        # ============================================================
        # SILVER → GOLD: Aggregation
        # ============================================================

        # Load silver rows stamped by this run and normalise the fields used below
        silver_status = F.upper(F.col("compliance_status"))
        silver = (
            spark.table(silver_tbl)
            .where(F.col("cf_processed_time") == F.to_timestamp(job_date))
            # NOTE(review): OCSF rows may carry the captions "Pass"/"Fail" rather
            # than Security Hub's PASSED/FAILED. TODO confirm against bronze data.
            # Folding both spellings together stops such rows from becoming UNKNOWN.
            .withColumn(
                "compliance_status",
                F.when(silver_status == "PASS", "PASSED")
                 .when(silver_status == "FAIL", "FAILED")
                 .otherwise(silver_status)
            )
            .withColumn(
                "severity",
                F.when(F.col("severity").isNull(), "unclassified")
                 .otherwise(F.lower("severity"))
            )
        )

        silver_count = silver.count()
        print(f"Silver rows for aggregation: {silver_count}")

        if silver_count == 0:
            print(f"⚠️  No silver data for {company_id}. Skipping gold aggregation...")
            skipped_companies.append((company_id, "No silver data for job date"))
            continue

        # Control-level aggregation: one row per control
        control_key = ["cf_processed_time", "account_id", "region_id", "standard_id", "control_id"]

        controls = (
            silver
            .groupBy(*control_key)
            .agg(
                F.max(F.when(F.col("compliance_status") == "FAILED", 1).otherwise(0)).alias("has_failed"),
                F.min(F.when(F.col("compliance_status") == "PASSED", 1).otherwise(0)).alias("all_passed"),
                F.max(F.when(F.col("compliance_status").isin("WARNING", "NOT_AVAILABLE"), 1).otherwise(0)).alias("has_unknown"),
                # Most severe label wins. A plain F.max("severity") compares the
                # strings alphabetically ("unclassified" > "medium" > "low" > ...
                # > "critical"), which picks an arbitrary label, not the worst one.
                F.max(
                    F.struct(
                        severity_rank(F.col("severity")).alias("rank"),
                        F.col("severity").alias("label"),
                    )
                ).alias("_worst_severity")
            )
            .withColumn("severity", F.col("_worst_severity.label"))
            .drop("_worst_severity")
            .withColumn(
                "control_status",
                F.when(F.col("has_failed") == 1, "FAILED")
                 .when(F.col("all_passed") == 1, "PASSED")
                 .otherwise("UNKNOWN")
            )
        )

        # Severity-level aggregation
        std_key = ["cf_processed_time", "account_id", "region_id", "standard_id"]

        severity_agg = (
            controls
            .groupBy(*std_key, "severity")
            .agg(
                # Each `controls` row is one control. Counting rows keeps `total`
                # consistent with `passed`; countDistinct would drop a NULL
                # control_id and could let passed exceed total.
                F.count(F.lit(1)).alias("total"),
                F.sum(F.when(F.col("control_status") == "PASSED", 1).otherwise(0)).cast("int").alias("passed")
            )
            .withColumn(
                "score",
                F.round(
                    F.when(F.col("total") > 0, F.col("passed") * 100.0 / F.col("total"))
                     .otherwise(0.0),
                    2
                )
            )
        )

        # Standard-level aggregation
        standards = (
            severity_agg
            .groupBy(*std_key)
            .agg(
                F.sum("total").alias("total"),
                F.sum("passed").alias("passed"),
                F.collect_list(
                    F.struct(
                        F.col("severity").alias("level"),
                        "score",
                        F.struct("total", "passed").alias("controls")
                    )
                ).alias("controls_by_severity")
            )
            .withColumn(
                "score",
                F.round(
                    F.when(F.col("total") > 0, F.col("passed") * 100.0 / F.col("total"))
                     .otherwise(0.0),
                    2
                )
            )
            .select(
                *std_key,
                F.struct(
                    F.col("standard_id").alias("std"),
                    "score",
                    F.struct("total", "passed").alias("controls"),
                    "controls_by_severity"
                ).alias("standard_summary")
            )
        )

        # Account/region summary
        gold_key = ["cf_processed_time", "account_id", "region_id"]

        overall = (
            controls
            .groupBy(*gold_key)
            .agg(
                # `controls` is already unique per (standard_id, control_id) within
                # gold_key, so a row count equals the distinct (standard, control) count.
                F.count(F.lit(1)).alias("total_rules"),
                F.sum(F.when(F.col("control_status") == "PASSED", 1).otherwise(0)).cast("int").alias("total_passed")
            )
            .withColumn(
                "control_pass_rate",
                F.when(F.col("total_rules") > 0, F.col("total_passed") / F.col("total_rules"))
                 .otherwise(0.0)
            )
        )

        gold = (
            overall
            .join(
                standards.groupBy(*gold_key)
                         .agg(F.collect_list("standard_summary").alias("standards_summary")),
                gold_key
            )
        )

        gold_count = gold.count()
        print(f"Gold summary rows: {gold_count}")

        if gold_count > 0:
            gold.createOrReplaceTempView("gold_updates")

            spark.sql(f"""
                MERGE INTO {gold_tbl} t
                USING gold_updates s
                   ON t.cf_processed_time = s.cf_processed_time
                  AND t.account_id = s.account_id
                  AND t.region_id = s.region_id
                WHEN MATCHED THEN UPDATE SET *
                WHEN NOT MATCHED THEN INSERT *
            """)

            print("✓ Gold MERGE completed")

        print(f"✓ Successfully processed {company_id}")
        successful_companies.append(company_id)

    except Exception as e:
        # Record the failure and carry on with the remaining companies.
        print(f"❌ ERROR processing {company_id}: {str(e)}")
        traceback.print_exc()
        failed_companies.append((company_id, str(e)))

print("\n" + "=" * 60)
print("ETL Pipeline Summary")
print("=" * 60)
print(f"Total companies: {len(companies_to_process)}")
print(f"Successful: {len(successful_companies)}")
print(f"Skipped: {len(skipped_companies)}")
print(f"Failed: {len(failed_companies)}")

if successful_companies:
    print(f"\n✓ Successful: {', '.join(successful_companies)}")

if skipped_companies:
    print("\n⚠️  Skipped:")
    for comp_id, reason in skipped_companies:
        print(f"  - {comp_id}: {reason}")

# Maximum number of error-message characters shown per failed company.
ERROR_PREVIEW_CHARS = 100

if failed_companies:
    print("\n❌ Failed:")
    for comp_id, error in failed_companies:
        # Only add an ellipsis when the message was actually truncated.
        preview = error if len(error) <= ERROR_PREVIEW_CHARS else error[:ERROR_PREVIEW_CHARS] + "..."
        print(f"  - {comp_id}: {preview}")

print("=" * 60)

# Per-company errors were caught in the processing loop so the remaining
# companies still ran. Raise here so a scheduled job run is marked as failed
# rather than reporting success with missing data.
if failed_companies:
    failed_ids = ", ".join(comp_id for comp_id, _ in failed_companies)
    raise RuntimeError(f"ETL failed for {len(failed_companies)} company(ies): {failed_ids}")