In [0]:
from pyspark.sql.functions import *

In [0]:
# --- Source (silver) location -------------------------------------------------
catalog = "workspace"
source_schema = "silver"
source_object = "silver_bookings"
fact_table = f"{catalog}.{source_schema}.{source_object}"

# --- Target (gold) location ---------------------------------------------------
target_schema = "gold"
target_object = "FactBookings"

# --- Incremental-load settings ------------------------------------------------
# cdc_column is the change-tracking column compared against the watermark.
# Leave backdated_refresh empty to derive the watermark from the target table;
# set it to a date/timestamp string to force re-processing from that point.
cdc_column = "modifiedDate"
backdated_refresh = ""

# Columns that identify a fact row; they are ANDed together as the MERGE match condition.
fact_key_cols = ["DimPassengersKey", "DimFlightsKey", "DimAirportsKey", "booking_date"]


In [0]:
# Dimension lookups LEFT JOINed onto the silver fact rows. Each entry has:
#   table     - fully qualified gold dimension table
#   alias     - SQL alias; the query builder selects "<alias>.<TableName>Key" from it
#   join_keys - (fact_column, dimension_column) pairs ANDed into the join condition
dimensions = [
    {
        "table": f"{catalog}.{target_schema}.DimPassengers",
        "alias": "DimPassengers",
        "join_keys": [("passenger_id", "passenger_id")],
    },
    {
        "table": f"{catalog}.{target_schema}.DimFlights",
        "alias": "DimFlights",
        "join_keys": [("flight_id", "flight_id")],
    },
    {
        "table": f"{catalog}.{target_schema}.DimAirports",
        "alias": "DimAirports",
        "join_keys": [("airport_id", "airport_id")],
    },
]

# Columns copied straight from the silver fact source (selected with the "f." alias).
fact_columns = ["amount", "booking_date", "modifiedDate"]

In [0]:
# Determine the CDC watermark: only silver rows whose cdc_column is on/after this value
# are re-processed. A non-empty backdated_refresh overrides the automatic watermark.
target_table_fqn = f"{catalog}.{target_schema}.{target_object}"
# Fallback watermark, presumably older than any real cdc value, i.e. a full load.
initial_load_watermark = "1900-01-01 00:00:00"

if backdated_refresh:
    # Manual override (also tolerates backdated_refresh = None, which len() did not).
    last_load = backdated_refresh
elif spark.catalog.tableExists(target_table_fqn):
    # Query the same fully qualified table that tableExists() checked, rather than a
    # hard-coded catalog name.
    last_load = spark.sql(
        f"SELECT max({cdc_column}) FROM {target_table_fqn}"
    ).collect()[0][0]
    # max() over an empty table returns NULL -> None; fall back to a full load instead
    # of propagating None as the watermark.
    if last_load is None:
        last_load = initial_load_watermark
else:
    last_load = initial_load_watermark


In [0]:
def generate_fact_query_incremental(fact_table, dimensions, fact_columns, cdc_column, processing_date):
    """Build the incremental SELECT that enriches fact rows with dimension surrogate keys.

    Args:
        fact_table: Fully qualified source fact table, aliased as ``f`` in the query.
        dimensions: List of dicts with keys ``table`` (fully qualified name),
            ``alias`` (SQL alias) and ``join_keys`` (list of
            ``(fact_column, dimension_column)`` pairs). Each dimension is LEFT
            JOINed, and ``<alias>.<TableName>Key`` is added to the select list.
        fact_columns: Columns selected directly from the fact table.
        cdc_column: Fact column compared against ``processing_date``.
        processing_date: Watermark value; it is wrapped in ``DATE(...)``, so rows
            on or after that calendar date are selected.

    Returns:
        The SQL query string.
    """
    fact_alias = "f"

    select_cols = [f"{fact_alias}.{col}" for col in fact_columns]
    join_clauses = []
    for dim in dimensions:
        table_full = dim["table"]
        alias = dim["alias"]
        # The surrogate key column is expected to be named "<TableName>Key".
        table_name = table_full.split(".")[-1]
        select_cols.append(f"{alias}.{table_name}Key")

        on_conditions = " AND ".join(
            f"{fact_alias}.{fk} = {alias}.{dk}" for fk, dk in dim["join_keys"]
        )
        join_clauses.append(f"LEFT JOIN {table_full} {alias} ON {on_conditions}")

    select_clause = ",\n    ".join(select_cols)
    joins = "\n".join(join_clauses)

    # Use the caller-supplied watermark instead of reading a notebook global.
    where_clause = f"{fact_alias}.{cdc_column} >= DATE('{processing_date}')"

    query = f"""
SELECT
    {select_clause}
FROM {fact_table} {fact_alias}
{joins}
WHERE {where_clause}
""".strip()

    return query


In [0]:
# Build the incremental extraction SQL, passing the watermark explicitly.
query = generate_fact_query_incremental(
    fact_table=fact_table,
    dimensions=dimensions,
    fact_columns=fact_columns,
    cdc_column=cdc_column,
    processing_date=last_load,
)

In [0]:
# Execute the generated SQL; the resulting frame is the MERGE source.
df_fact = spark.sql(sqlQuery=query)

In [0]:
# MERGE predicate: every fact key column must match between source (src) and target (trg).
# NOTE(review): plain `=` never matches NULL, so rows whose dimension lookup failed
# (NULL surrogate key from the LEFT JOINs) will never match and may be re-inserted on
# each run. Consider null-safe `<=>` if that is not intended; confirm.
fact_key_cols_str = " AND ".join(f"src.{key} = trg.{key}" for key in fact_key_cols)

'src.DimPassengersKey = trg.DimPassengersKey AND src.DimFlightsKey = trg.DimFlightsKey AND src.DimAirportsKey = trg.DimAirportsKey AND src.booking_date = trg.booking_date'

In [0]:
from delta.tables import DeltaTable

target_table_name = f"{catalog}.{target_schema}.{target_object}"

if not spark.catalog.tableExists(target_table_name):
    # First run: create the gold fact table from the extracted frame.
    (
        df_fact.write.format("delta")
        .mode("append")
        .saveAsTable(target_table_name)
    )
else:
    # Later runs: upsert on the fact key. A matched row is overwritten only when the
    # source's CDC value is at least as recent as the target's.
    dlt_obj = DeltaTable.forName(spark, target_table_name)
    (
        dlt_obj.alias("trg")
        .merge(df_fact.alias("src"), fact_key_cols_str)
        .whenMatchedUpdateAll(condition=f"src.{cdc_column} >= trg.{cdc_column}")
        .whenNotMatchedInsertAll()
        .execute()
    )
