In [0]:
%run ../.././start_up 

In [0]:
# Logger for this notebook. The notebook reads from Bronze and maintains the
# Silver `treatments` table, so it is labelled `treatments_to_silver`; the
# previous `landing_to_bronze` label attributed its log lines to another stage.
logger = create_logger(notebook_name="treatments_to_silver", log_level="DEBUG")
logger.info("🚀 Initializing treatments_to_silver notebook")

# Extract frequently used config values into variables.
# `pipeline_config` is presumably provided by the `%run start_up` cell — verify.
catalog = pipeline_config["catalog"]
bronze_schema = pipeline_config["schemas"]["bronze"]
bronze_path = pipeline_config["paths"]["bronze_path"]
bronze_volume_path = pipeline_config["paths"]["bronze_volume_path"]
silver_schema = pipeline_config["schemas"]["silver"]
silver_path = pipeline_config["paths"]["silver_path"]
landing_schema = pipeline_config["schemas"]["landing"]
landing_path = pipeline_config["paths"]["landing_path"]
logs_schema = pipeline_config["schemas"]["logs"]
# Logical table handled by this notebook; used as the key into the
# column-mapping and table-config lookups and to build paths/table names.
table_name = "treatments"
logger.info("Extracted frequently used config values into variables")

In [0]:

# --- Setup ---
# Both the Silver schema and the DQ error schema must exist before any table
# is written into them.
for _schema_name in (silver_schema, "silver_errors"):
    spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog}.{_schema_name}")

# Paths and targets
# Source: the Bronze volume folder holding this table's CSV files.
bronze_input_path = f"{bronze_volume_path}/{table_name}"
# Sinks: the Silver history table and the table collecting rejected rows.
silver_output_table = f"{catalog}.{silver_schema}.{table_name}"
error_output_table = f"{catalog}.silver_errors.{table_name}_errors"

In [0]:
# treatments_to_silver.py (SCD Type 2 - Final: Accurate Change Detection)
from pyspark.sql.functions import col, current_timestamp, lit, trim, upper, monotonically_increasing_id, when
from datetime import datetime
from delta.tables import DeltaTable


# Read CSV
def read_bronze_csv(path):
    """Load the CSV data at ``path`` into a Spark DataFrame.

    The first line of each file is treated as the header. No schema is
    supplied or inferred here.

    Args:
        path: Location of the CSV file(s) to read.

    Returns:
        The DataFrame produced by the Spark CSV reader.
    """
    logger.info(f"📥 Reading from Bronze: {path}")
    csv_reader = spark.read.option("header", "true")
    return csv_reader.csv(path)

# Rename Columns
def apply_column_mapping(df):
    """Rename source columns to their target names for the current table.

    The rename pairs come from ``column_mappings[table_name]``; entries whose
    source column is absent from ``df`` are skipped.

    Args:
        df: DataFrame whose columns should be renamed.

    Returns:
        A DataFrame with every applicable rename applied.
    """
    logger.info("🔀 Applying column mappings")
    renames = column_mappings[table_name]
    for source_name in renames:
        if source_name not in df.columns:
            continue
        df = df.withColumnRenamed(source_name, renames[source_name])
    return df

# Enforce Schema
def enforce_data_types(df):
    """Cast columns to the types configured for the current table.

    Column/type pairs are read from ``table_config[table_name]["columns"]``;
    configured columns that are not present in ``df`` are ignored.

    Args:
        df: DataFrame to cast.

    Returns:
        A DataFrame in which each configured, present column has been cast.
    """
    logger.info("🧪 Enforcing data types")
    configured_types = table_config[table_name]["columns"]
    for column_name in configured_types:
        if column_name not in df.columns:
            continue
        target_type = configured_types[column_name]
        df = df.withColumn(column_name, col(column_name).cast(target_type))
    return df

# DQ Checks for treatments
def run_dq_checks(df):
    """Collect the rows of ``df`` that violate the treatments DQ rules.

    Rules (each applied only when its column exists in ``df``):
      * ``treatment_id`` is null
      * ``patient_id`` is null
      * ``treatment_type`` is an empty string after trimming

    A row that breaks several rules appears once per broken rule.

    Args:
        df: DataFrame to validate.

    Returns:
        A DataFrame of failing rows with an extra ``dq_error`` column holding
        the reason, or ``None`` when none of the checked columns are present.
    """
    logger.info("✅ Running DQ checks for 'treatments'")

    # (column, predicate factory, reason). Predicates are built lazily so a
    # column expression is only created for rules that actually run.
    rules = (
        ("treatment_id", lambda: col("treatment_id").isNull(), "treatment_id is null"),
        ("patient_id", lambda: col("patient_id").isNull(), "patient_id is null"),
        ("treatment_type", lambda: trim(col("treatment_type")) == "", "treatment_type is blank"),
    )

    dq_errors = None
    for column_name, build_predicate, reason in rules:
        if column_name not in df.columns:
            continue
        failing_rows = df.filter(build_predicate()).withColumn("dq_error", lit(reason))
        if dq_errors is None:
            dq_errors = failing_rows
        else:
            dq_errors = dq_errors.unionByName(failing_rows)

    return dq_errors

# Filter and write errors
def filter_and_store_errors(df, dq_errors):
    """Persist DQ failures and return only rows whose key never failed a check.

    Args:
        df: The DataFrame that was validated; must contain ``treatment_id``
            when ``dq_errors`` is not ``None``.
        dq_errors: DataFrame of failing rows (as returned by
            ``run_dq_checks``), or ``None`` when no check was applicable.

    Returns:
        ``df`` itself when ``dq_errors`` is ``None``. Otherwise ``df`` without
        any row whose ``treatment_id`` occurs in ``dq_errors`` — including rows
        whose ``treatment_id`` is NULL when a NULL key was reported.

    Side effects:
        Appends the error rows, extended with ``dq_error_id`` and ``error_ts``,
        to the Delta table ``error_output_table``.
    """
    logger.info("🗑️ Filtering bad records")
    if dq_errors is None:
        return df

    error_rows = (
        dq_errors
        .withColumn("dq_error_id", monotonically_increasing_id())
        .withColumn("error_ts", current_timestamp())
    )
    error_rows.write.format("delta").option("mergeSchema", "true").mode("append").saveAsTable(error_output_table)
    logger.info(f"⚠️ {error_rows.count()} DQ errors written to: {error_output_table}")

    # The key column is renamed on the error side so the join condition can
    # refer to each side unambiguously by name.
    bad_keys = dq_errors.select(col("treatment_id").alias("_bad_treatment_id")).distinct()

    # Null-safe comparison: with plain equality NULL never equals NULL, so rows
    # rejected for "treatment_id is null" would survive the anti join and be
    # treated as valid downstream.
    valid_df = df.join(
        bad_keys,
        on=col("treatment_id").eqNullSafe(col("_bad_treatment_id")),
        how="left_anti",
    )

    # Keep the key-first column order that the previous name-based join
    # produced, so downstream table layouts do not shift.
    remaining_cols = [c for c in df.columns if c != "treatment_id"]
    return valid_df.select("treatment_id", *remaining_cols)

# Apply SCD Type 2
def _open_scd2_version(df):
    """Return ``df`` with SCD2 bookkeeping columns marking each row as the open, current version."""
    return (
        df.withColumn("valid_from", current_timestamp())
        .withColumn("valid_to", lit(None).cast("timestamp"))
        .withColumn("is_current", lit(True))
    )


def apply_scd_type2(df):
    """Apply ``df`` to the Silver table as Slowly Changing Dimension Type 2 history.

    Behaviour:
      * If ``silver_output_table`` does not exist, it is created from ``df``
        with every row marked as the current version.
      * Otherwise rows of ``df`` are compared with the current versions on
        ``treatment_id``. For keys that are new or whose non-key attributes
        differ, the existing current version (if any) is closed
        (``valid_to`` set, ``is_current = false``) and the incoming row is
        appended as the new current version.

    ``ingestion_date`` and the SCD bookkeeping columns are excluded from
    change detection.

    Args:
        df: Cleaned source rows; must contain ``treatment_id``.

    Returns:
        None. The function only writes to ``silver_output_table``.
    """
    logger.info(f"🔁 Applying SCD Type 2 to: {silver_output_table}")

    # First load: nothing to compare against, every row is a current version.
    if not spark.catalog.tableExists(silver_output_table):
        _open_scd2_version(df).write.format("delta").saveAsTable(silver_output_table)
        logger.info(f"✅ Created new SCD Type 2 table: {silver_output_table}")
        return

    current_df = spark.table(silver_output_table).filter("is_current = true")
    join_keys = ["treatment_id"]
    not_compared = set(join_keys + ["valid_from", "valid_to", "is_current", "ingestion_date"])
    compare_cols = [c for c in df.columns if c not in not_compared]

    join_cond = " AND ".join([f"source.{k} = target.{k}" for k in join_keys])
    # NULL-safe comparison per attribute. With no comparable attributes the
    # expression falls back to `false` so the SQL below stays valid.
    change_expr = " OR ".join(
        [f"target.{c} IS DISTINCT FROM source.{c}" for c in compare_cols]
    ) or "false"

    joined_df = df.alias("source").join(current_df.alias("target"), on=join_keys, how="left")
    # A NULL target key means the source key has no current version yet.
    changed_df = joined_df.filter(f"{change_expr} OR target.{join_keys[0]} IS NULL").select("source.*")

    # Count once: each count() re-executes the whole comparison join.
    changed_count = changed_df.count()
    logger.info(f"📌 Changed/new rows: {changed_count}")

    if changed_count == 0:
        logger.info("✅ No changes found. Skipping SCD2 merge.")
        return

    changed_df = _open_scd2_version(changed_df)

    # Step 1: close the current version of every key whose attributes changed.
    # NOTE(review): `valid_to` here and `valid_from` of the appended rows come
    # from two separate Spark actions, so the timestamps will not be identical
    # — confirm whether point-in-time queries need them to line up exactly.
    DeltaTable.forName(spark, silver_output_table).alias("target").merge(
        source=changed_df.alias("source"),
        condition=join_cond + " AND target.is_current = true"
    ).whenMatchedUpdate(
        condition=change_expr,
        set={
            "valid_to": "current_timestamp()",
            "is_current": "false"
        }
    ).execute()

    # Step 2: append the incoming rows as the new current versions. All source
    # columns are written (same shape as the first-load path): a narrower,
    # hard-coded column list would store NULL for any other compared column
    # and make those rows look changed again on every subsequent run.
    changed_df.write.format("delta").option("mergeSchema", "true").mode("append").saveAsTable(silver_output_table)
    logger.info("✅ SCD2 merge completed: history updated and new rows inserted.")

# Execute pipeline
# Stage 1: read the Bronze CSV, rename columns, then cast to configured types.
bronze_df = enforce_data_types(
    apply_column_mapping(
        read_bronze_csv(bronze_input_path)
    )
)

# Stage 2: stamp an ingestion time only when the source did not carry one.
bronze_df = (
    bronze_df
    if "ingestion_date" in bronze_df.columns
    else bronze_df.withColumn("ingestion_date", current_timestamp())
)

# Stage 3: validate, divert rejected rows, and apply the rest to Silver.
dq_errors = run_dq_checks(bronze_df)
clean_df = filter_and_store_errors(bronze_df, dq_errors)
apply_scd_type2(clean_df)
