# GOLD LAYER

In [47]:
import polars as pl
from datetime import date, timedelta, datetime
import os
import sys
from pathlib import Path
from dotenv import load_dotenv

sys.path.append(os.path.abspath(".."))
import udfs as udf

In [48]:
# Load variables from the .env file; override=True lets .env values replace
# variables that are already set in the process environment.
load_dotenv(override=True)

# Fail with an explicit message when the warehouse root is not configured,
# instead of the opaque TypeError that Path(None) would raise.
_polars_dwh_root = os.getenv("POLARS_DWH")
if _polars_dwh_root is None:
    raise RuntimeError(
        "Environment variable POLARS_DWH is not set; "
        "define it in the environment or in the .env file."
    )

POLARS_DWH = Path(_polars_dwh_root)
print(POLARS_DWH)

# Layer directories used by the cells below.
staging_dir = POLARS_DWH / 'staging_layer'
gold_dir = POLARS_DWH / 'gold_layer'
print(staging_dir, gold_dir, sep="\n")

/home/sapna.choudhary/Data-Engineering-Training/Polars_DWH
/home/sapna.choudhary/Data-Engineering-Training/Polars_DWH/staging_layer
/home/sapna.choudhary/Data-Engineering-Training/Polars_DWH/gold_layer


In [49]:
def load_dim_date(start_date: str = "1954-08-18", end_date: str = "2025-07-30") -> pl.DataFrame:
    """Build the date dimension for an inclusive date range.

    Args:
        start_date: First calendar day to generate, formatted ``YYYY-MM-DD``.
        end_date: Last calendar day to generate (inclusive), formatted
            ``YYYY-MM-DD``.

    Returns:
        A DataFrame sorted by ``date`` with one row per calendar day plus one
        placeholder row dated 1900-01-01 (``weekday_name`` "Unknown",
        ``fiscal_quarter`` "Q0"). Columns: ``date``, ``year``, ``month``,
        ``day``, ``weekday``, ``weekday_name``, ``is_weekend``,
        ``fiscal_quarter``.

    Raises:
        ValueError: If either argument does not match ``%Y-%m-%d``.
    """
    start = datetime.strptime(start_date, "%Y-%m-%d").date()
    end = datetime.strptime(end_date, "%Y-%m-%d").date()

    num_days = (end - start).days + 1
    date_list = [start + timedelta(days=i) for i in range(num_days)]

    # Explicit dtype so an empty range (end before start) still yields a Date
    # column that the .dt expressions below can operate on.
    df = pl.DataFrame({"date": date_list}, schema={"date": pl.Date})

    # Polars dt.weekday() uses ISO numbering: 1 = Monday ... 7 = Sunday.
    iso_weekday = pl.col("date").dt.weekday()

    df = df.with_columns([
        pl.col("date").dt.year().alias("year"),
        pl.col("date").dt.month().alias("month"),
        pl.col("date").dt.day().alias("day"),
        iso_weekday.alias("weekday"),   # 1=Mon, 7=Sun (ISO)
        pl.col("date").dt.strftime("%A").alias("weekday_name"),
        # Weekend = Saturday (6) or Sunday (7) under ISO numbering.
        pl.when(iso_weekday.is_in([6, 7]))
          .then(pl.lit(1)).otherwise(pl.lit(0))
          .alias("is_weekend"),
        # NOTE: this is the calendar quarter of the date, prefixed with "Q".
        ("Q" + pl.col("date").dt.quarter().cast(pl.Utf8)).alias("fiscal_quarter"),
    ])

    # Enforce types for the base frame.
    df = df.cast({
        "date": pl.Date,
        "year": pl.Int32,
        "month": pl.Int32,
        "day": pl.Int32,
        "weekday": pl.Int32,
        "weekday_name": pl.Utf8,
        "is_weekend": pl.Int32,
        "fiscal_quarter": pl.Utf8,
    })

    # Placeholder row, cast to the same dtypes so it can be stacked vertically.
    df_unknown = pl.DataFrame({
        "date": [date(1900, 1, 1)],
        "year": [0],
        "month": [0],
        "day": [0],
        "weekday": [0],
        "weekday_name": ["Unknown"],
        "is_weekend": [0],
        "fiscal_quarter": ["Q0"],
    }).cast(df.schema)

    df_final = pl.concat([df, df_unknown], how="vertical")

    return df_final.sort("date")

In [111]:
def _stamp_dim_rows(
    keys: pl.DataFrame,
    key_col: str,
    first_skey: int,
    batch_id: int,
    load_ts: datetime,
) -> pl.DataFrame:
    """Attach a sequential surrogate key and the audit columns to a frame of keys."""
    skey_col = f"{key_col}_skey"
    return (
        keys
        .with_row_index(skey_col, offset=first_skey)
        .with_columns([
            pl.col(skey_col).cast(pl.Int64),
            pl.lit(1).alias("is_active"),
            pl.lit(batch_id).alias("batch_id"),
            pl.lit(load_ts).alias("load_timestamp"),
        ])
    )


def merge_upsert_dim(
    df_src: pl.DataFrame,
    df_tgt: pl.DataFrame | None,
    key_col: str,
    batch_id: int,
    keep_unknown: bool = False,
    unknown_value: str = "Unknown"
) -> pl.DataFrame:
    """
    Generalized UPSERT (MERGE) for dimension tables in Polars.

    Rules:
      - If key exists in both → keep the target row unchanged
      - If key exists only in source → insert with the next surrogate keys
      - If key exists only in target → delete (unless keep_unknown=True and
        the key equals unknown_value)
      - If keep_unknown=True → always ensure an 'Unknown' row exists, on the
        first load and on incremental loads alike

    Args:
        df_src: Source frame; only ``key_col`` is read. Nulls and duplicates
            are dropped.
        df_tgt: Current dimension table, or None / empty on the first load.
        key_col: Name of the natural-key column. The surrogate key column is
            named ``f"{key_col}_skey"``.
        batch_id: Value written to ``batch_id`` on inserted rows.
        keep_unknown: Whether the ``unknown_value`` row is protected/ensured.
        unknown_value: Natural-key value of the placeholder row.

    Returns:
        The merged dimension sorted by surrogate key, with columns
        ``<key>_skey``, ``<key>``, ``is_active``, ``batch_id``,
        ``load_timestamp``.
    """
    skey_col = f"{key_col}_skey"
    # One timestamp for the whole call so every row written by this load
    # carries the same load_timestamp.
    load_ts = datetime.now()

    # --- 1. Clean + distinct source.
    # maintain_order=True keeps first-seen order; without it the row order of
    # unique() is not guaranteed, which made surrogate-key assignment
    # non-reproducible between runs.
    src_keys = (
        df_src
        .select(key_col)
        .drop_nulls()
        .unique(subset=[key_col], maintain_order=True)
    )

    # --- 2. First load (target missing or empty)
    if df_tgt is None or df_tgt.is_empty():
        keys = src_keys
        if keep_unknown:
            # Take Unknown out of the regular keys and append it last so it
            # receives the highest surrogate key of this load.
            unknown_key = pl.DataFrame({key_col: [unknown_value]}).cast(src_keys.schema)
            keys = pl.concat(
                [src_keys.filter(pl.col(key_col) != unknown_value), unknown_key],
                how="vertical",
            )
        return _stamp_dim_rows(keys, key_col, 1, batch_id, load_ts).sort(skey_col)

    # --- 3. Incremental merge (target already exists) ---

    # INSERT: keys present in source but not in target.
    new_keys = src_keys.filter(
        ~pl.col(key_col).is_in(df_tgt.select(key_col)[key_col].implode())
    )

    # Unknown row currently stored in the target (empty frame when not flagged).
    keep_unknown_row = (
        df_tgt.filter(pl.col(key_col) == unknown_value) if keep_unknown else df_tgt.head(0)
    )

    if keep_unknown:
        new_keys = new_keys.filter(pl.col(key_col) != unknown_value)
        if keep_unknown_row.is_empty():
            # Target has no Unknown row yet: create it together with the new
            # keys so it gets a fresh surrogate key.
            unknown_key = pl.DataFrame({key_col: [unknown_value]}).cast(new_keys.schema)
            new_keys = pl.concat([new_keys, unknown_key], how="vertical")

    if new_keys.is_empty():
        # Empty frame with the target schema keeps the concat below uniform.
        new_rows = df_tgt.head(0)
    else:
        start_offset = int(df_tgt[skey_col].max()) + 1
        new_rows = _stamp_dim_rows(new_keys, key_col, start_offset, batch_id, load_ts)

    # KEEP: keys present in both source and target (target rows win).
    common = df_tgt.filter(
        pl.col(key_col).is_in(src_keys.select(key_col)[key_col].implode())
    )

    # Final merge result; unique() removes the Unknown row duplicate that
    # appears when the source itself contains unknown_value.
    df_final = (
        pl.concat([common, new_rows, keep_unknown_row], how="vertical")
        .unique(subset=[key_col], keep="first")
        .sort(skey_col)
    )

    return df_final


In [None]:
table_name = 'customer'
# Dimension column -> whether its 'Unknown' row must be preserved/ensured.
key_cols = {"gender": False, "marital_status": False, "customer_type": False, "account_status": False, "region": True}

# Every dimension below is derived from the same staging extract, so read it
# once instead of once per loop iteration.
src_file_name = staging_dir / f"staging_{table_name}.parquet"
df_staging = pl.read_parquet(src_file_name)

for col, keep_unknown in key_cols.items():
    print("\n--- Running merge for:", col, "keep_unknown:", keep_unknown)

    # Rebind per iteration so an (optional) test filter on df_src below does
    # not leak into the next dimension.
    df_src = df_staging

    # load gold table if exists
    tgt_file_name = gold_dir / f"gold_dim_{col}.parquet"
    df_tgt = pl.read_parquet(tgt_file_name) if tgt_file_name.exists() else None

    # ##### TEST FILTER #####
    # if col == 'gender':
    #     df_src = df_src.filter(pl.col(col) != 'Female')
    #     print(df_src['gender'].unique())
    # if col == 'region':
    #     df_src = df_src.filter(pl.col(col) != 'Africa')
    #     print(df_src['region'].unique())

    # merge
    df_tgt = merge_upsert_dim(df_src, df_tgt, col, batch_id=1, keep_unknown=keep_unknown)
    print(df_tgt)

    # save
    df_tgt.write_parquet(tgt_file_name)

gender False
shape: (4, 5)
┌─────────────┬─────────┬───────────┬──────────┬────────────────────────────┐
│ gender_skey ┆ gender  ┆ is_active ┆ batch_id ┆ load_timestamp             │
│ ---         ┆ ---     ┆ ---       ┆ ---      ┆ ---                        │
│ i64         ┆ str     ┆ i32       ┆ i32      ┆ datetime[μs]               │
╞═════════════╪═════════╪═══════════╪══════════╪════════════════════════════╡
│ 1           ┆ Other   ┆ 1         ┆ 1        ┆ 2025-09-02 16:12:50.984558 │
│ 2           ┆ Male    ┆ 1         ┆ 1        ┆ 2025-09-02 16:12:50.984558 │
│ 4           ┆ Unknown ┆ 1         ┆ 1        ┆ 2025-09-02 16:12:50.984558 │
│ 5           ┆ Female  ┆ 1         ┆ 1        ┆ 2025-09-02 16:17:35.144812 │
└─────────────┴─────────┴───────────┴──────────┴────────────────────────────┘
marital_status False
shape: (3, 5)
┌─────────────────────┬────────────────┬───────────┬──────────┬────────────────────────────┐
│ marital_status_skey ┆ marital_status ┆ is_active ┆ batch_id ┆ l

In [52]:
# Read the dimension back from gold_dir (where the merge loop writes it)
# rather than from a path relative to the current working directory.
df_dim_marital_status_parquet = pl.read_parquet(gold_dir / "gold_dim_marital_status.parquet")
df_dim_marital_status_parquet

marital_status_skey,marital_status,is_active,batch_id,load_timestamp
i64,str,i32,i32,datetime[μs]
1,"""Unknown""",1,1,2025-09-02 16:12:50.988122
2,"""Married""",1,1,2025-09-02 16:12:50.988122
3,"""Single""",1,1,2025-09-02 16:12:50.988122


In [53]:
table_name = 'product'
# Dimension column -> whether its 'Unknown' row must be preserved/ensured.
key_cols = {"brand_tier": False, "brand_name": False, "brand_country": False}

# Every dimension below is derived from the same staging extract, so read it
# once instead of once per loop iteration.
src_file_name = staging_dir / f"staging_{table_name}.parquet"
df_staging = pl.read_parquet(src_file_name)

for col, keep_unknown in key_cols.items():
    print(col, keep_unknown)

    df_src = df_staging

    # load gold table if exists
    tgt_file_name = gold_dir / f"gold_dim_{col}.parquet"
    df_tgt = pl.read_parquet(tgt_file_name) if tgt_file_name.exists() else None

    # merge
    df_tgt = merge_upsert_dim(df_src, df_tgt, col, batch_id=1, keep_unknown=keep_unknown)
    print(df_tgt)

    # save
    df_tgt.write_parquet(tgt_file_name)

brand_tier False


shape: (4, 5)
┌─────────────────┬────────────┬───────────┬──────────┬────────────────────────────┐
│ brand_tier_skey ┆ brand_tier ┆ is_active ┆ batch_id ┆ load_timestamp             │
│ ---             ┆ ---        ┆ ---       ┆ ---      ┆ ---                        │
│ i64             ┆ str        ┆ i32       ┆ i32      ┆ datetime[μs]               │
╞═════════════════╪════════════╪═══════════╪══════════╪════════════════════════════╡
│ 1               ┆ Unknown    ┆ 1         ┆ 1        ┆ 2025-09-02 16:12:51.088061 │
│ 2               ┆ Standard   ┆ 1         ┆ 1        ┆ 2025-09-02 16:12:51.088061 │
│ 3               ┆ Economy    ┆ 1         ┆ 1        ┆ 2025-09-02 16:12:51.088061 │
│ 4               ┆ Premium    ┆ 1         ┆ 1        ┆ 2025-09-02 16:12:51.088061 │
└─────────────────┴────────────┴───────────┴──────────┴────────────────────────────┘
brand_name False
shape: (14, 5)
┌─────────────────┬────────────┬───────────┬──────────┬────────────────────────────┐
│ brand_name_skey ┆

In [54]:
table_name = 'order'
# Dimension column -> whether its 'Unknown' row must be preserved/ensured.
key_cols = {"payment_source": False, "lead_type": False}

# Every dimension below is derived from the same staging extract, so read it
# once instead of once per loop iteration.
src_file_name = staging_dir / f"staging_{table_name}.parquet"
df_staging = pl.read_parquet(src_file_name)

for col, keep_unknown in key_cols.items():
    print(col, keep_unknown)

    df_src = df_staging

    # load gold table if exists
    tgt_file_name = gold_dir / f"gold_dim_{col}.parquet"
    df_tgt = pl.read_parquet(tgt_file_name) if tgt_file_name.exists() else None

    # merge
    df_tgt = merge_upsert_dim(df_src, df_tgt, col, batch_id=1, keep_unknown=keep_unknown)
    print(df_tgt)

    # save
    df_tgt.write_parquet(tgt_file_name)

payment_source False
shape: (7, 5)
┌─────────────────────┬────────────────┬───────────┬──────────┬────────────────────────────┐
│ payment_source_skey ┆ payment_source ┆ is_active ┆ batch_id ┆ load_timestamp             │
│ ---                 ┆ ---            ┆ ---       ┆ ---      ┆ ---                        │
│ i64                 ┆ str            ┆ i32       ┆ i32      ┆ datetime[μs]               │
╞═════════════════════╪════════════════╪═══════════╪══════════╪════════════════════════════╡
│ 1                   ┆ Debit Card     ┆ 1         ┆ 1        ┆ 2025-09-02 16:12:51.248975 │
│ 2                   ┆ Paypal         ┆ 1         ┆ 1        ┆ 2025-09-02 16:12:51.248975 │
│ 3                   ┆ Net Banking    ┆ 1         ┆ 1        ┆ 2025-09-02 16:12:51.248975 │
│ 4                   ┆ Upi            ┆ 1         ┆ 1        ┆ 2025-09-02 16:12:51.248975 │
│ 5                   ┆ Credit Card    ┆ 1         ┆ 1        ┆ 2025-09-02 16:12:51.248975 │
│ 6                   ┆ Unknown    

In [63]:
import polars as pl
from datetime import datetime

# def merge_scd2_dim_brand(df_src: pl.DataFrame, df_tgt: pl.DataFrame | None, batch_id: int) -> pl.DataFrame:
def merge_scd2_dim(
    df_src: pl.DataFrame,
    df_tgt: pl.DataFrame | None,
    key_col: str,                # e.g. "brand_skey"
    attr_cols: list[str],        # e.g. ["brand_tier_skey", "brand_name_skey", "brand_country_skey"]
    batch_id: int
) -> pl.DataFrame:
    """
    SCD-2 merge of a source snapshot into a dimension table using Polars.

    Rules (per value of ``key_col``, compared against the *current* target row):
      - Key in both and all ``attr_cols`` equal -> keep as is
      - Key only in source -> insert as new current row
      - Key only in target -> expire the current row
      - Key in both but any attribute differs -> expire current row and insert
        a new version

    Attribute comparison is null-safe: NULL vs NULL counts as equal, NULL vs a
    value counts as a change.

    Args:
        df_src: Source snapshot; must contain ``key_col`` and ``attr_cols``.
        df_tgt: Existing dimension (with ``is_active``, ``start_date``,
            ``end_date``, ``is_current``, ``batch_id``, ``load_timestamp``),
            or None / empty on the first load.
        key_col: Column identifying a dimension member.
        attr_cols: Tracked attribute columns; a change in any of them creates
            a new version. An empty list means no row is ever "changed".
        batch_id: Value written to ``batch_id`` on inserted rows.

    Returns:
        On the first load: ``df_src`` plus the SCD columns. Otherwise the
        updated target (integer columns widened to Int64) with the new
        versions appended, sorted by ``key_col``.
    """
    # One clock reading per call so start_date/end_date/load_timestamp agree.
    load_ts = datetime.now()
    today = load_ts.date()

    scd_columns = [
        pl.lit(1).alias("is_active"),
        pl.lit(today).alias("start_date"),
        pl.lit(None).cast(pl.Date).alias("end_date"),
        pl.lit(1).alias("is_current"),
        pl.lit(batch_id).alias("batch_id"),
        pl.lit(load_ts).alias("load_timestamp"),
    ]

    if df_tgt is None or df_tgt.is_empty():
        # first load → insert all rows as current
        return df_src.with_columns(scd_columns)

    # -------------------
    # 1. Expire current rows whose attributes changed or that left the source
    # -------------------
    tgt_active = df_tgt.filter(pl.col("is_current") == 1)

    # Case A: key exists in both but attributes differ → expire old.
    # Inner join keeps matched keys only; source columns that collide with
    # target columns get the "_src" suffix.
    matched = tgt_active.join(df_src, on=key_col, how="inner", suffix="_src")
    attr_changed = pl.lit(False)
    for c in attr_cols:
        attr_changed = attr_changed | pl.col(c).ne_missing(pl.col(f"{c}_src"))
    changed_keys = matched.filter(attr_changed)[key_col].unique().to_list()

    # Case B: key is current in target but absent from source → expire old.
    missing_keys = (
        tgt_active
        .join(df_src.select(key_col), on=key_col, how="anti")[key_col]
        .drop_nulls()
        .unique()
        .to_list()
    )

    expire_keys = changed_keys + missing_keys

    # Only the *current* version is expired; historical rows for the same key
    # keep their original end_date.
    is_expiring = (pl.col("is_current") == 1) & pl.col(key_col).is_in(expire_keys)
    df_tgt = df_tgt.with_columns(
        pl.when(is_expiring).then(0).otherwise(pl.col("is_active")).alias("is_active"),
        pl.when(is_expiring).then(0).otherwise(pl.col("is_current")).alias("is_current"),
        pl.when(is_expiring).then(today).otherwise(pl.col("end_date")).alias("end_date"),
    )

    # -------------------
    # 2. Insert new versions for changed or new keys
    # -------------------
    # Case C: keys with no current row in the target.
    new_rows = df_src.filter(~pl.col(key_col).is_in(tgt_active[key_col].to_list()))

    # Case D: new version for changed rows.
    changed_rows = df_src.filter(pl.col(key_col).is_in(changed_keys))

    inserts = (
        pl.concat([new_rows, changed_rows])
        .with_columns(scd_columns)
        .sort(key_col)
    )

    # Align schemas to avoid Int32/Int64 mismatch between target and inserts.
    int_widening = {
        name: pl.Int64
        for name, dtype in zip(df_tgt.columns, df_tgt.dtypes)
        if dtype in (pl.Int32, pl.Int64)
    }
    df_tgt = df_tgt.cast(int_widening)
    # Reorder to the target's column order: a vertical concat matches columns
    # by position, and the SCD columns are appended in a different order here
    # than in a target created elsewhere.
    inserts = inserts.cast(int_widening).select(df_tgt.columns)

    # -------------------
    # 3. Final result
    # -------------------
    df_final = pl.concat([df_tgt, inserts], how="vertical").sort(key_col)

    return df_final


In [64]:
# SCD-2 load of the product dimension: staging extract -> gold parquet file.
table_name = "product"
key_col = "brand_skey"
attr_cols = ["brand_tier_skey", "brand_name_skey", "brand_country_skey"]

# Source (staging) and target (gold) parquet locations for this table.
src_file_name = staging_dir / f"staging_{table_name}.parquet"
tgt_file_name = gold_dir / f"gold_dim_{table_name}.parquet"

df_src = pl.read_parquet(src_file_name)

# The target is absent on the very first run; the merge handles None.
if os.path.exists(tgt_file_name):
    df_tgt = pl.read_parquet(tgt_file_name)
else:
    df_tgt = None

df_tgt = merge_scd2_dim(
    df_src=df_src,
    df_tgt=df_tgt,
    key_col=key_col,
    attr_cols=attr_cols,
    batch_id=1,
)

# Persist the merged dimension over the previous version.
df_tgt.write_parquet(tgt_file_name)


In [None]:
# Read the dimension back from gold_dir (where the SCD-2 cell writes it)
# rather than from a path relative to the current working directory.
df_dim_parquet = pl.read_parquet(gold_dir / f"gold_dim_{table_name}.parquet")
df_dim_parquet

In [None]:
# import polars as pl
# from datetime import datetime

# def merge_scd2_dim_brand(df_src: pl.DataFrame, df_tgt: pl.DataFrame | None, batch_id: int) -> pl.DataFrame:
#     """
#     SCD-2 Merge for dim_brand using Polars.
    
#     Rules:
#       - If brand_skey exists in both and attributes match -> keep as is
#       - If brand_skey exists only in source -> insert as new row
#       - If brand_skey exists only in target -> expire old row
#       - If brand_skey exists in both but attributes differ -> expire old + insert new version
#     """

#     today = datetime.now().date()

#     if df_tgt is None or df_tgt.is_empty():
#         # first load → insert all rows as current
#         return (
#             df_src
#             .with_columns([
#                 pl.lit(1).alias("is_active"),
#                 pl.lit(today).alias("start_date"),
#                 pl.lit(None).cast(pl.Date).alias("end_date"),
#                 pl.lit(1).alias("is_current"),
#                 pl.lit(batch_id).alias("batch_id"),
#                 pl.lit(datetime.now()).alias("load_timestamp"),
#             ])
#         )
#     # print("\n============\ndf_tgt : ", df_tgt)
#     # -------------------
#     # 1. Expire rows where attributes changed or missing in source
#     # -------------------
#     tgt_active = df_tgt.filter(pl.col("is_current") == 1)
#     # print("\n============\ntgt_active : ", tgt_active)

#     # Join source with active target
#     joined = tgt_active.join(
#         df_src,
#         on="brand_skey",
#         how="full",
#         suffix="_src"
#     )
#     # print("\n============\njoined : ", joined)

#     # Case A: Row exists in both but attributes differ → expire old
#     expire_changed = joined.filter(
#         (pl.col("brand_skey").is_not_null()) & (
#             (pl.col("brand_tier_skey") != pl.col("brand_tier_skey_src")) |
#             (pl.col("brand_name_skey") != pl.col("brand_name_skey_src")) |
#             (pl.col("brand_country_skey") != pl.col("brand_country_skey_src"))
#         )
#     ).select("brand_skey")
#     # print("\n============\nexpire_changed : ", expire_changed)

#     # Case B: Row exists in target but missing in source → expire old
#     expire_missing = joined.filter(pl.col("brand_tier_skey_src").is_null()).select("brand_skey")
#     # print("\n============\nexpire_missing : ", expire_missing)

#     to_expire = pl.concat([expire_changed, expire_missing]).unique()
#     # print("\n============\nto_expire : ", to_expire)
    
#     expire_keys = to_expire["brand_skey"].to_list()
#     df_tgt = df_tgt.with_columns(
#         # pl.when(pl.col("brand_skey").is_in(to_expire["brand_skey"]))
#         #   .then(0).otherwise(pl.col("is_active")).alias("is_active"),
#         # pl.when(pl.col("brand_skey").is_in(to_expire["brand_skey"]))
#         #   .then(0).otherwise(pl.col("is_current")).alias("is_current"),
#         # pl.when(pl.col("brand_skey").is_in(to_expire["brand_skey"]))
#         #   .then(today).otherwise(pl.col("end_date")).alias("end_date")
        
#         pl.when(pl.col("brand_skey").is_in(expire_keys))
#             .then(0).otherwise(pl.col("is_active")).alias("is_active"),
#         pl.when(pl.col("brand_skey").is_in(expire_keys))
#             .then(0).otherwise(pl.col("is_current")).alias("is_current"),
#         pl.when(pl.col("brand_skey").is_in(expire_keys))
#             .then(today).otherwise(pl.col("end_date")).alias("end_date")
#     )
#     # print("\n============\ndf_tgt : ", df_tgt)
#     # print("\n============\ndf_tgt : ", df_tgt.to_pandas())

#     # -------------------
#     # 2. Insert new versions for changed or new keys
#     # -------------------
#     # Case C: Brand-new rows
#     # new_rows = df_src.filter(~pl.col("brand_skey").is_in(tgt_active["brand_skey"]))
#     new_rows = df_src.filter(~pl.col("brand_skey").is_in(tgt_active["brand_skey"].to_list()))
#     # print("\n============\nnew_rows : ", new_rows)

#     # Case D: New version for changed rows
#     # changed_rows = df_src.filter(pl.col("brand_skey").is_in(expire_changed["brand_skey"]))
#     changed_rows = df_src.filter(pl.col("brand_skey").is_in(expire_changed["brand_skey"].to_list()))
#     # print("\n============\nchanged_rows : ", changed_rows)

#     inserts = pl.concat([new_rows, changed_rows]).with_columns([
#         pl.lit(today).alias("start_date"),
#         pl.lit(None).cast(pl.Date).alias("end_date"),
#         pl.lit(1).alias("is_current"),
#         pl.lit(1).alias("is_active"),
#         pl.lit(batch_id).alias("batch_id"),
#         pl.lit(datetime.now()).alias("load_timestamp"),
#     ]).sort("brand_skey")
#     # print("\n============\ninserts : ", inserts)

#     # print("\n============\n df_tgt : ", df_tgt.columns)
#     # print("\n============\n inserts : ", inserts.columns)
    
#     # align schemas to avoid Int32/Int64 mismatch
#     common_schema = {col: pl.Int64 for col, dtype in zip(df_tgt.columns, df_tgt.dtypes) if dtype in (pl.Int32, pl.Int64)}

#     df_tgt = df_tgt.cast(common_schema)
#     inserts = inserts.cast(common_schema)

#     # -------------------
#     # 3. Final result
#     # -------------------
#     df_final = pl.concat([df_tgt, inserts], how="vertical").sort("brand_skey")
#     # print("\n============\ndf_final : ", df_final)

#     return df_final


In [None]:
# ### --- TEST CASE ---
# df_tgt = pl.DataFrame({
#     "brand_skey": [1, 2, 3],
#     "brand_tier_skey": [5, 6, 7],
#     "brand_name_skey": [10, 20, 30],
#     "brand_country_skey": [100, 200, 300],
#     "start_date": [datetime(2023,1,1).date()]*3,
#     "end_date": [None, None, None],
#     "is_current": [1, 1, 1],
#     "is_active": [1, 1, 1],
#     "batch_id": [0, 0, 0],
#     "load_timestamp": [datetime(2023,1,1)]*3,
# })

# df_src = pl.DataFrame({
#     "brand_skey": [1, 3, 4],
#     "brand_tier_skey": [5, 7, 8],
#     "brand_name_skey": [10, 50, 40],
#     "brand_country_skey": [100, 600, 400],
# })

# df_final = merge_scd2_dim_brand(df_src, df_tgt, batch_id=1)
# print(df_final)


In [88]:
import polars as pl
from datetime import datetime

# ---------- 1. Root-level (simple UPSERT) ----------
def merge_upsert_dim_chain(
    df_src: pl.DataFrame,
    df_tgt: pl.DataFrame | None,
    key_col: str,
    batch_id: int,
    keep_unknown: bool = False,
    unknown_value: str = "Unknown",
) -> pl.DataFrame:
    """Upsert the distinct values of ``key_col`` into a root-level dimension.

    Keys found only in the source are inserted with the next surrogate keys;
    keys found only in the target are removed; keys in both keep their target
    row. With ``keep_unknown=True`` the ``unknown_value`` row is created when
    absent and is never removed, on first and incremental loads alike.

    Args:
        df_src: Source frame; only ``key_col`` is read. Nulls and duplicates
            are dropped.
        df_tgt: Current dimension table, or None / empty on the first load.
        key_col: Natural-key column; the surrogate key column is
            ``f"{key_col}_skey"``.
        batch_id: Value written to ``batch_id`` on inserted rows.
        keep_unknown: Whether the ``unknown_value`` row is protected/ensured.
        unknown_value: Natural-key value of the placeholder row.

    Returns:
        The merged dimension sorted by surrogate key.
    """
    skey_col = f"{key_col}_skey"
    # Single timestamp so every row written by this call shares it.
    load_ts = datetime.now()

    def stamp(keys: pl.DataFrame, first_skey: int) -> pl.DataFrame:
        # Number the keys from first_skey and add the audit columns.
        return (
            keys
            .with_row_index(skey_col, offset=first_skey)
            .with_columns([
                pl.col(skey_col).cast(pl.Int64),
                pl.lit(1).alias("is_active"),
                pl.lit(batch_id).alias("batch_id"),
                pl.lit(load_ts).alias("load_timestamp"),
            ])
        )

    # maintain_order=True: without it unique() may return rows in any order,
    # which makes the surrogate-key assignment non-reproducible.
    src_keys = df_src.select(key_col).drop_nulls().unique(maintain_order=True)

    # --- first load ---
    if df_tgt is None or df_tgt.is_empty():
        keys = src_keys
        if keep_unknown:
            # Unknown is appended last so it gets the highest surrogate key.
            unknown_key = pl.DataFrame({key_col: [unknown_value]}).cast(src_keys.schema)
            keys = pl.concat([src_keys.filter(pl.col(key_col) != unknown_value), unknown_key])
        return stamp(keys, 1).sort(skey_col)

    # --- incremental ---
    new_keys = src_keys.filter(~pl.col(key_col).is_in(df_tgt[key_col].implode()))

    # Existing Unknown row is carried over even when the source lacks it.
    unknown_row = (
        df_tgt.filter(pl.col(key_col) == unknown_value) if keep_unknown else df_tgt.head(0)
    )
    if keep_unknown:
        new_keys = new_keys.filter(pl.col(key_col) != unknown_value)
        if unknown_row.is_empty():
            # Target has no Unknown row yet: create it with a fresh key.
            unknown_key = pl.DataFrame({key_col: [unknown_value]}).cast(new_keys.schema)
            new_keys = pl.concat([new_keys, unknown_key])

    if new_keys.is_empty():
        # Empty frame with the target schema keeps the concat uniform.
        new_rows = df_tgt.head(0)
    else:
        new_rows = stamp(new_keys, int(df_tgt[skey_col].max()) + 1)

    # Keys present in both source and target keep their existing target row.
    common = df_tgt.filter(pl.col(key_col).is_in(src_keys[key_col].implode()))

    # unique() drops the duplicate Unknown row that appears when the source
    # itself contains unknown_value.
    return pl.concat([common, new_rows, unknown_row]).unique([key_col]).sort(skey_col)


# ---------- 2. Dependent levels (SCD2) ----------
def merge_scd2_dim_chain(df_src, df_tgt, key_col, attr_cols, batch_id):
    """Merge a source snapshot into a type-2 slowly changing dimension.

    ``df_src`` is treated as the full current snapshot of the level:

    * keys not present among the current target rows are inserted as new
      versions;
    * keys whose tracked attributes differ from the current target version
      have that version closed and a new version inserted;
    * keys that are current in the target but absent from the source have
      their current version closed.

    Only rows with ``is_current == 1`` are ever closed, so the ``end_date``
    of versions that were closed by an earlier batch is left untouched.

    Args:
        df_src: Polars DataFrame with ``key_col`` and the ``attr_cols``.
        df_tgt: Existing dimension, or ``None`` / empty for the initial load.
        key_col: Natural key column; the surrogate key column is named
            ``f"{key_col}_skey"``.
        attr_cols: Columns whose change triggers a new version. May be
            empty, in which case only inserts and removals are tracked.
        batch_id: Value written to ``batch_id`` on every inserted row.

    Returns:
        The merged dimension. The incremental path sorts it by surrogate
        key; the initial load returns rows in source order, numbered from 1.
    """
    skey_col = f"{key_col}_skey"
    # One clock reading per merge so every row of the batch shares the same stamps.
    now = datetime.now()
    today = now.date()

    def _as_new_versions(rows, first_skey):
        # Number the rows from first_skey and mark them as open, current versions.
        return (
            rows
            .with_row_index(skey_col, offset=first_skey)
            .with_columns([
                pl.col(skey_col).cast(pl.Int64),
                pl.lit(today).alias("start_date"),
                pl.lit(None).cast(pl.Date).alias("end_date"),
                pl.lit(1).alias("is_current"),
                pl.lit(1).alias("is_active"),
                pl.lit(batch_id).alias("batch_id"),
                pl.lit(now).alias("load_timestamp"),
            ])
        )

    # --- initial load ---
    if df_tgt is None or df_tgt.is_empty():
        return _as_new_versions(df_src, 1)

    # A target written before surrogate keys existed gets them assigned up
    # front, so the offset computation for inserts below always has a column
    # to read.
    if skey_col not in df_tgt.columns:
        df_tgt = (
            df_tgt
            .with_row_index(skey_col, offset=1)
            .with_columns(pl.col(skey_col).cast(pl.Int64))
        )

    tgt_active = df_tgt.filter(pl.col("is_current") == 1)

    # --- detect changed rows ---
    # Inner join keeps only keys present on both sides; clashing source
    # attribute columns receive the "_src" suffix.
    matched = tgt_active.join(df_src, on=key_col, how="inner", suffix="_src")
    # ne_missing treats null as a comparable value, so null -> value and
    # value -> null count as changes while null -> null does not.
    diffs = [pl.col(c).ne_missing(pl.col(f"{c}_src")) for c in attr_cols]
    if diffs:
        is_changed = diffs[0]
        for diff in diffs[1:]:
            is_changed = is_changed | diff
        changed_keys = matched.filter(is_changed).select(key_col)
    else:
        changed_keys = matched.head(0).select(key_col)

    # --- detect keys that disappeared from the source ---
    # Decided on key presence alone: a source row with a null attribute is
    # still present and must not be treated as missing.
    missing_keys = tgt_active.join(df_src, on=key_col, how="anti").select(key_col)

    expire_keys = pl.concat([changed_keys, missing_keys])[key_col].unique()

    # Close only the *current* version of each affected key; historical
    # versions keep the end_date they were closed with.
    is_expiring = (
        pl.col(key_col).is_in(expire_keys.implode())
        & (pl.col("is_current") == 1)
    )
    df_tgt = df_tgt.with_columns([
        pl.when(is_expiring).then(pl.lit(0)).otherwise(pl.col("is_active")).alias("is_active"),
        pl.when(is_expiring).then(pl.lit(0)).otherwise(pl.col("is_current")).alias("is_current"),
        pl.when(is_expiring).then(pl.lit(today)).otherwise(pl.col("end_date")).alias("end_date"),
    ])

    # --- build inserts: brand-new keys plus new versions of changed keys ---
    new_rows = df_src.filter(~pl.col(key_col).is_in(tgt_active[key_col].implode()))
    changed_rows = df_src.filter(pl.col(key_col).is_in(changed_keys[key_col].implode()))
    pending = pl.concat([new_rows, changed_rows])

    if pending.is_empty():
        inserts = df_tgt.head(0)
    else:
        # Continue numbering after the highest surrogate key ever issued.
        inserts = _as_new_versions(pending, int(df_tgt[skey_col].max()) + 1)

    return pl.concat([df_tgt, inserts]).sort(skey_col)


In [112]:
# Build the geographic dimension chain (region -> country -> state -> city ->
# postal_code) from the customer staging table, one level at a time so each
# level can look up the surrogate key its parent level has just written.
tbl = "customer"
df_customer = pl.read_parquet(staging_dir / f"staging_{tbl}.parquet")

# (table name, natural key column, tracked attribute columns).
# An empty attribute list marks the root level, which is merged as a plain
# upsert; the other levels are merged as SCD2 on their parent surrogate key.
hierarchy = [
    ("region", "region", []),
    ("country", "country", ["region_skey"]),
    ("state", "state", ["country_skey"]),
    ("city", "city", ["state_skey"]),
    ("postal_code", "postal_code", ["city_skey"]),
]

# child table -> (parent natural key column, parent surrogate key column)
parent_map = {
    "country": ("region", "region_skey"),
    "state": ("country", "country_skey"),
    "city": ("state", "state_skey"),
    "postal_code": ("city", "city_skey"),
}

for table_name, key_col, attr_cols in hierarchy:
    print(f"Processing {table_name}...")

    if not attr_cols:
        # Root level: region
        df_src = df_customer.select([key_col]).drop_nulls().unique()
    else:
        # Need to lookup parent surrogate key
        parent_name, parent_skey = parent_map[table_name]

        parent_file = gold_dir / f"gold_dim_{parent_name}.parquet"
        df_parent = pl.read_parquet(parent_file)

        # An SCD2 parent keeps its expired versions. Joining against all of
        # them would emit one child row per parent version and could attach
        # a stale surrogate key, so only current versions take part in the
        # lookup. The root (upsert) dimension has no is_current column and is
        # used as-is.
        if "is_current" in df_parent.columns:
            df_parent = df_parent.filter(pl.col("is_current") == 1)

        df_src = (
            df_customer
            .select([key_col, parent_name])          # natural keys only
            .drop_nulls()
            .unique()
            .join(df_parent.select([parent_name, parent_skey]), on=parent_name, how="left")
            .select([key_col, parent_skey])
        )

    # Load gold if exists
    tgt_file_name = gold_dir / f"gold_dim_{table_name}.parquet"
    df_tgt = pl.read_parquet(tgt_file_name) if tgt_file_name.exists() else None

    # Merge
    if not attr_cols:
        df_tgt = merge_upsert_dim(df_src, df_tgt, key_col, batch_id=1)
    else:
        df_tgt = merge_scd2_dim_chain(df_src, df_tgt, key_col, attr_cols, batch_id=1)

    df_tgt.write_parquet(tgt_file_name)


Processing region...
Processing country...
Processing state...
Processing city...
Processing postal_code...


In [116]:
# Inspect a gold dimension. Read it from gold_dir, the directory the loader
# cell writes to, instead of relying on the notebook's working directory.
table_name = 'region'
df_dim_parquet = pl.read_parquet(gold_dir / f"gold_dim_{table_name}.parquet")
df_dim_parquet

region_skey,region,is_active,batch_id,load_timestamp
i64,str,i32,i32,datetime[μs]
1,"""South America""",1,1,2025-09-03 14:35:03.756456
2,"""Asia""",1,1,2025-09-03 14:35:03.756456
3,"""Europe""",1,1,2025-09-03 14:35:03.756456
4,"""North America""",1,1,2025-09-03 14:35:03.756456
5,"""Africa""",1,1,2025-09-03 14:35:03.756456


In [114]:
# ### --- TEST CASE ---
# # 1. Region
# df_region_src = pl.DataFrame({"region": ["Asia", "Europe"]})
# df_region_tgt = None
# df_region_final = merge_upsert_dim(df_region_src, df_region_tgt, "region", batch_id=1)

# # 2. Country (depends on region)
# df_country_src = pl.DataFrame({"country": ["India", "Germany"], "region_skey": [1, 2]})
# df_country_tgt = None
# df_country_final = merge_scd2_dim_chain(df_country_src, df_country_tgt, "country", ["region_skey"], batch_id=1)

# # 3. State (depends on country)
# df_state_src = pl.DataFrame({"state": ["Karnataka", "Bavaria"], "country_skey": [1, 2]})
# df_state_tgt = None
# df_state_final = merge_scd2_dim_chain(df_state_src, df_state_tgt, "state", ["country_skey"], batch_id=1)

# print(df_region_src)
# print(df_region_final)
# print(df_country_src)
# print(df_country_final)
# print(df_state_src)
# print(df_state_final)


In [115]:
# ### --- TEST CASE ---
# # ✅ new batch with a change: "Europe" → "EU"
# # First load (initial batch)
# df_region_final = merge_upsert_dim(
#     pl.DataFrame({"region": ["Asia", "Europe"]}),
#     None,   # no previous tgt
#     "region",
#     batch_id=1
# )

# # Second load (update batch)
# df_region_src_new = pl.DataFrame({"region": ["Asia", "EU"]})
# df_region_final_updated = merge_upsert_dim(
#     df_region_src_new,
#     df_region_final,   # use previous tgt (with region_skey)
#     "region",
#     batch_id=2
# )

# print("\n=================\n", df_region_src_new)
# print(df_region_final)
# print(df_region_final_updated)

# # ✅ Suppose India’s region_skey changed from 1→3 (new region mapping)
# df_country_src_new = pl.DataFrame({"country": ["India", "Germany"], "region_skey": [3, 2]})
# # Initial load for country
# df_country_final = merge_scd2_dim_chain(
#     pl.DataFrame({"country": ["India", "Germany"], "region_skey": [1, 2]}),
#     None,
#     "country",
#     ["region_skey"],
#     batch_id=1
# )

# # Next batch (update)
# df_country_final_updated = merge_scd2_dim_chain(
#     df_country_src_new,
#     df_country_final,
#     "country",
#     ["region_skey"],
#     batch_id=2
# )


# print("\n=================\n", df_country_src_new)
# print(df_country_final)
# print(df_country_final_updated)

# # ✅ Rename Karnataka → Bengaluru
# df_state_src_new = pl.DataFrame({"state": ["Bengaluru", "Bavaria"], "country_skey": [1, 2]})
# # Initial load for state
# df_state_final = merge_scd2_dim_chain(
#     pl.DataFrame({"state": ["Karnataka", "Bavaria"], "country_skey": [1, 2]}),
#     None,
#     "state",
#     ["country_skey"],
#     batch_id=1
# )

# # Next batch (update)
# df_state_final_updated = merge_scd2_dim_chain(
#     df_state_src_new,
#     df_state_final,   # previous tgt
#     "state",
#     ["country_skey"],
#     batch_id=2
# )

# print("\n=================\n", df_state_src_new)
# print(df_state_final)
# print(df_state_final_updated)
