GOLD LAYER SCRIPTS

patient data load script

In [0]:
from delta.tables import DeltaTable
from pyspark.sql import functions as F
from datetime import datetime

# Load the patient dimension (gold_dim_patient) as a Type-2 slowly changing
# dimension from silver_fhir_patient.
#
# Steps:
#   1. Stamp the silver rows with this run's load timestamp.
#   2. Compare them to the currently active dimension rows (gender, birthDate).
#   3. For new/changed patients: close the active row (end date + flag) and
#      add a fresh active row with a new surrogate key.
#   4. Rewrite the table in a single overwrite so the close-out and the
#      inserts land together.
#
# Imported here so the cell stays runnable on its own; used to tell
# "target table does not exist yet" apart from every other failure.
from pyspark.sql.utils import AnalysisException

load_ts = datetime.now().isoformat()

# Read from silver and add load timestamp
silver_patient = (
    spark.table("silver_fhir_patient")
    .withColumn("load_timestamp", F.lit(load_ts))
)

# Only a missing (or non-Delta) target is treated as "first load". Any other
# error must propagate: silently falling into the first-load path would
# overwrite the whole dimension and lose its history.
try:
    dim_patient = DeltaTable.forName(spark, "gold_dim_patient")
except AnalysisException:
    dim_patient = None
    dim_patient_df = None
else:
    dim_patient_df = dim_patient.toDF()

if dim_patient_df is not None:
    # Join current records for comparison
    joined = (
        silver_patient.alias("s")
        .join(
            dim_patient_df.filter("current_flag = true").alias("d"),
            F.col("s.id") == F.col("d.patient_id"),
            "left"
        )
    )

    # Identify new or changed records. The comparison is null-safe: a plain
    # `!=` evaluates to NULL when either side is NULL, which would hide
    # changes such as NULL -> 'female'.
    changes = joined.filter(
        F.col("d.patient_id").isNull()
        | ~F.col("s.gender").eqNullSafe(F.col("d.gender"))
        | ~F.col("s.birthDate").eqNullSafe(F.col("d.birthDate"))
    ).select("s.*")
else:
    changes = silver_patient

# Proceed only if changes exist (limit(1) avoids counting every row just to
# test for emptiness).
if changes.limit(1).count() > 0:
    # monotonically_increasing_id() restarts from 0 on every run, so new keys
    # are shifted past the highest key already stored to keep patient_sk
    # unique across loads. On a first load the offset is 0.
    sk_offset = 0
    if dim_patient_df is not None:
        max_sk = dim_patient_df.agg(F.max("patient_sk")).first()[0]
        if max_sk is not None:
            sk_offset = max_sk + 1

    # New SCD2 rows (every row on a first-time load) with a new surrogate key
    new_rows = (
        changes
        .withColumnRenamed("id", "patient_id")
        .withColumn("patient_sk", F.monotonically_increasing_id() + F.lit(sk_offset))
        .withColumn("effective_start_date", F.lit(load_ts))
        .withColumn("effective_end_date", F.lit(None).cast("string"))
        .withColumn("current_flag", F.lit(True))
    )

    if dim_patient_df is not None:
        dim_columns = dim_patient_df.columns

        # One marker row per changed patient. Joining only the key (instead of
        # the full `changes` frame) keeps dimension column names unambiguous.
        changed_ids = changes.select(
            F.col("id").alias("patient_id"),
            F.lit(True).alias("_is_changed"),
        ).distinct()

        # True only for an active row whose patient changed; NULL/False
        # otherwise, which falls through to `.otherwise(...)` below.
        is_closing = F.col("current_flag") & F.col("_is_changed")

        # Existing rows, with the superseded active rows closed out IN PLACE.
        # effective_end_date is set before current_flag because `is_closing`
        # reads the original current_flag value.
        existing_rows = (
            dim_patient_df
            .join(changed_ids, on="patient_id", how="left")
            .withColumn(
                "effective_end_date",
                F.when(is_closing, F.lit(load_ts)).otherwise(F.col("effective_end_date")),
            )
            .withColumn(
                "current_flag",
                F.when(is_closing, F.lit(False)).otherwise(F.col("current_flag")),
            )
            .select(*dim_columns)
        )

        # Final dataframe for overwrite: history + closed rows + new versions
        final_df = existing_rows.unionByName(new_rows.select(*dim_columns))
    else:
        # First-time load
        final_df = new_rows

    # Overwrite gold_dim_patient with new version
    final_df.write.format("delta").mode("overwrite").saveAsTable("gold_dim_patient")


observation data loading script

In [0]:
from delta.tables import DeltaTable
from pyspark.sql import functions as F
from datetime import datetime

# Load the observation fact table (gold_fact_observation) from
# silver_fhir_observation.
#
# Steps:
#   1. Stamp silver observations with this run's load timestamp.
#   2. Resolve each observation to the active patient surrogate key.
#   3. Flatten the first category / interpretation / performer / profile entry.
#   4. Apply data-quality filters and de-duplicate on observation_id.
#   5. Insert rows whose observation_id is not yet in the target (Delta MERGE),
#      or create the table when it does not exist.
#
# Imported here so the cell stays runnable on its own; used to tell
# "target table does not exist yet" apart from every other failure.
from pyspark.sql.utils import AnalysisException

# Set the load timestamp
load_ts = datetime.now().isoformat()

# Read Silver Layer Observation table
silver_obs = (
    spark.table("silver_fhir_observation")
    .withColumn("load_timestamp", F.lit(load_ts))
)

# Read Dimension Patient table (only active rows)
dim_patient_active = (
    spark.table("gold_dim_patient")
    .filter("current_flag = true")
    .select("patient_id", "patient_sk")
)

# Join observations to dimension patients
# NOTE(review): this assumes subject_reference already holds the bare patient
# id; a raw FHIR reference such as "Patient/<id>" would never match - confirm
# against the silver layer.
fact_obs = (
    silver_obs.alias("obs")
    .join(
        dim_patient_active.alias("dim"),
        F.col("obs.subject_reference") == F.col("dim.patient_id"),
        "left"
    )
)

# Extract useful nested fields from arrays (first element only)
fact_obs_enriched = (
    fact_obs
    .withColumn("category_code", F.expr("category[0].coding[0].code"))
    .withColumn("category_display", F.expr("category[0].coding[0].display"))
    .withColumn("category_system", F.expr("category[0].coding[0].system"))
    .withColumn("interpretation_code", F.expr("interpretation[0].coding[0].code"))
    .withColumn("interpretation_display", F.expr("interpretation[0].coding[0].display"))
    .withColumn("interpretation_system", F.expr("interpretation[0].coding[0].system"))
    .withColumn("performer_reference", F.expr("performer[0].reference"))
    .withColumn("meta_profile", F.expr("meta.profile[0]"))
    .withColumnRenamed("id", "observation_id")
)

# Look up the target table. Only the lookup is guarded: a failure inside the
# MERGE further down must surface as an error instead of falling through to
# an overwrite that would replace the whole fact table with this batch.
try:
    delta_table = DeltaTable.forName(spark, "gold_fact_observation")
except AnalysisException:
    delta_table = None

# monotonically_increasing_id() restarts from 0 on every run, so new keys are
# shifted past the highest key already stored to keep fact_observation_sk
# unique across loads. On a first load the offset is 0.
sk_offset = 0
if delta_table is not None:
    max_sk = delta_table.toDF().agg(F.max("fact_observation_sk")).first()[0]
    if max_sk is not None:
        sk_offset = max_sk + 1

# Define allowed statuses for validation
allowed_statuses = ["final", "registered", "preliminary"]

# Apply Data Quality Checks
fact_obs_dq = (
    fact_obs_enriched
    .filter(F.col("observation_id").isNotNull())
    # Drops observations with no active patient match (left join above).
    .filter(F.col("patient_sk").isNotNull())
    .filter(F.col("status").isin(allowed_statuses))
    # Parsed only to validate the timestamp; the raw string is what is stored.
    .withColumn("effectiveDateTime_parsed", F.to_timestamp("effectiveDateTime"))
    .filter(F.col("effectiveDateTime_parsed").isNotNull())
    .filter(F.col("code_code").isNotNull() & F.col("code_system").isNotNull())
    .filter(
        F.col("value_quantity_value").isNull() | (F.col("value_quantity_value") >= 0)
    )
    .dropDuplicates(["observation_id"])
    .withColumn(
        "fact_observation_sk",
        F.monotonically_increasing_id() + F.lit(sk_offset),
    )
)

# Select final columns for the fact table
fact_obs_final = (
    fact_obs_dq
    .select(
        "fact_observation_sk",
        "observation_id",
        "patient_sk",
        "status",
        "code_code",
        "code_display",
        "code_system",
        "category_code",
        "category_display",
        "category_system",
        "effectiveDateTime",
        "performer_reference",
        "interpretation_code",
        "interpretation_display",
        "interpretation_system",
        "value_quantity_value",
        "value_quantity_unit",
        "resourceType",
        "load_timestamp"
    )
)

if delta_table is None:
    # Table does not exist, so create it
    fact_obs_final.write.format("delta").mode("overwrite").saveAsTable("gold_fact_observation")
else:
    # Perform Delta MERGE to avoid duplicates: only observation_ids not yet
    # present in the target are inserted; existing rows are left unchanged.
    delta_table.alias("tgt").merge(
        fact_obs_final.alias("src"),
        "tgt.observation_id = src.observation_id"
    ).whenNotMatchedInsertAll().execute()

