In [None]:
# Databricks notebook: Process input CSV files, separate valid/invalid records, write valid to Delta table, log failed records
"""
This notebook processes input CSV files from S3, validates required columns, writes valid records to a Delta table partitioned by client/date/zip5, and writes failed records to a separate S3 location. Failed record counts are logged for each client.
"""

# Import required libraries
import logging
import os
import re
from datetime import datetime, timezone

import boto3
import dbutils
from pyspark.sql import SparkSession
from pyspark.sql.functions import input_file_name, col, lit, udf, sha2, substring, regexp_replace
from pyspark.sql.types import StringType

# Get file path from Databricks widget
# The widget is declared with an empty default, so `filepath` is "" when the
# caller does not supply a value.
dbutils.widgets.text("filepath", "", "S3 File Path")
filepath = dbutils.widgets.get("filepath")

# Set output locations and date
# Use a timezone-aware UTC timestamp: datetime.utcnow() is deprecated since
# Python 3.12. The formatted value (YYYY/MM/DD, UTC) is unchanged.
today = datetime.now(timezone.utc).strftime("%Y/%m/%d")
processed_bucket = "s3a://radiant-graph-delta-table/customer_data_by_client_date_zip"
failed_bucket = "s3a://radiant-graph-input-failed"

# Set up logger for failed record counts
# NOTE(review): no handler is attached here, so INFO records are only emitted
# if a handler is configured elsewhere (e.g. on the root logger) — confirm.
logger = logging.getLogger("FailedRecordLogger")
logger.setLevel(logging.INFO)

try:
    # Create (or reuse) a Spark session with the Delta SQL extension and catalog
    spark = (
        SparkSession.builder
        .appName("ProcessInputFiles")
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
        .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
        .getOrCreate()
    )

    # Read input CSV file from S3 and record which file each row came from.
    # NOTE(review): inferSchema=True may type zip5 / phone / member_id as
    # numbers (dropping leading zeros) — confirm against the real input files.
    df = spark.read.csv(filepath, header=True, inferSchema=True) \
        .withColumn("source_file", input_file_name())

    # Extract client name from file path using regex
    def extract_client_name(path):
        """Return the path segment that follows ``client/``, or "unknown" if absent."""
        # `path or ""` keeps the UDF from raising if source_file is null.
        match = re.search(r"client/([^/]+)/", path or "")
        return match.group(1) if match else "unknown"

    extract_client_name_udf = udf(extract_client_name, StringType())
    df = df.withColumn("client_name", extract_client_name_udf(col("source_file")))
    df = df.withColumn("date", lit(today))

    # Define required columns for validation
    required_columns = [
        "member_id", "first_name", "last_name", "dob", "gender", "phone", "email", "zip5", "plan_id"
    ]

    # Separate invalid records (any required column is NULL) from valid ones.
    # Both sides are plain filters on the same predicate: `IS NULL` never
    # evaluates to NULL itself, so NOT(...) is the exact complement. This
    # replaces DataFrame.subtract, which is a set difference (EXCEPT DISTINCT)
    # and therefore silently collapsed duplicate valid rows.
    missing_required = " OR ".join(f"{name} IS NULL" for name in required_columns)
    invalid_df = df.filter(missing_required)
    valid_df = df.filter(f"NOT ({missing_required})")

    # Compliance transformation for valid records
    compliant_df = valid_df.select(
        sha2(col("member_id"), 256).alias("member_id_hash"),
        lit("REDACTED").alias("first_name"),
        lit("REDACTED").alias("last_name"),
        # First four characters of dob are kept as the birth year.
        substring(col("dob"), 1, 4).alias("birth_year"),
        col("gender"),
        # Masks each "ddd-ddd" run; phone values in other formats pass through unchanged.
        regexp_replace(col("phone"), r"\d{3}-\d{3}", "***-***").alias("phone_masked"),
        # Every run of non-"@" characters becomes "user" (e.g. "a@b.com" -> "user@user").
        regexp_replace(col("email"), r"[^@]+", "user").alias("email_masked"),
        substring(col("zip5"), 1, 3).alias("zip3"),
        col("plan_id"),
        col("client_name"),
        col("date"),
        col("zip5")
    )

    # Write compliant records to Delta table, partitioned and compressed
    # NOTE(review): mode("overwrite") without replaceWhere / dynamic partition
    # overwrite replaces the existing table contents at processed_bucket on
    # every run — confirm this is intended rather than an append or a
    # partition-scoped overwrite.
    compliant_df.write.format("delta") \
        .mode("overwrite") \
        .option("compression", "zstd") \
        .partitionBy("client_name", "date", "zip5") \
        .save(processed_bucket)

    # Write failed records to S3 and log counts per client
    total_records = df.count()
    failed_clients = [
        row["client_name"]
        for row in invalid_df.select("client_name").distinct().collect()
    ]
    for client in failed_clients:
        # Compare through the Column API instead of splicing the client name
        # into a SQL string, so quotes in a path-derived name cannot break
        # (or alter) the filter expression.
        client_failed_df = invalid_df.filter(col("client_name") == client)
        failed_count = client_failed_df.count()
        if failed_count > 0:
            # Spark writes a directory of part files at this path.
            client_failed_df.write.mode("overwrite").csv(
                f"{failed_bucket}/client/{client}/{today}/failed_records.csv",
                header=True
            )
            logger.info(f"Client: {client}, Date: {today}, Failed Records: {failed_count}, Total Records Processed: {total_records}")

    # Stop Spark session
    # NOTE(review): on a shared/notebook cluster this stops the session other
    # code may still be using — confirm this is wanted.
    spark.stop()
except Exception as e:
    # Top-level boundary: log, attempt a best-effort SNS alert, then re-raise
    # so the run is still reported as failed.
    message = f"ERROR: Data ingestion failed for file {filepath}: {str(e)}"
    logger.error(message)
    # Trigger SNS notification
    try:
        sns_topic_arn = os.environ.get("SNS_TOPIC_ARN")
        subject = "Radiant Graph Data Ingestion Failure"
        if sns_topic_arn:
            # Only create the client when a topic is configured, so an
            # unconfigured environment does not log a spurious SNS failure.
            sns_client = boto3.client("sns")
            sns_client.publish(TopicArn=sns_topic_arn, Message=message, Subject=subject)
    except Exception as sns_e:
        # A failed notification must not mask the original ingestion error.
        logger.error(f"ERROR: Failed to send SNS notification: {str(sns_e)}")
    raise
