### Necessary Imports

In [1]:
import os
import logging
from pathlib import Path
from typing import Optional
from datetime import datetime
from dotenv import load_dotenv

import polars as pl

from helper_utils import log_step


# Configure the root logger at INFO level and create the module-level logger
# used by the helper functions in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables from a .env file; override=True is passed so
# values from the file take precedence over variables already set.
load_dotenv(override=True) 

True

### MERGE HELPER FUNCTIONS
- Provides functions to define how to: 
    - apply joins
    - prepare the source
    - assign surrogate_keys
    - assign unknown_keys
    - handle data-loading
    - perform the merge operation

In [2]:
# Root locations are read from environment variables.
POLARS_DWH = Path(os.environ.get("POLARS_DWH"))
PARQUET_FILES_DIR = Path(os.environ.get("PARQUET_FILES_DIR"))

# Directory holding the gold-layer parquet files; created if it does not exist yet.
gold_parquet = PARQUET_FILES_DIR.joinpath("gold_layer")
gold_parquet.mkdir(parents=True, exist_ok=True)

In [3]:
# ------------------ Utility: apply joins ------------------
def apply_joins(base: pl.DataFrame, joins: list[tuple[pl.DataFrame, str | list[str], str]]) -> pl.DataFrame:
    """
    Chain several joins onto a base DataFrame.

    Args:
        base: Frame every join is applied to, in order.
        joins: Sequence of ``(frame, on, how)`` triples. Each triple is
            forwarded to ``DataFrame.join`` on the running result.

    Returns:
        The frame produced after the last join (``base`` itself when
        ``joins`` is empty).

    Raises:
        ValueError: If the frame of any triple is ``None``.
    """
    joined = base
    for spec in joins:
        other, on, how = spec
        # Fail early with the join parameters in the message instead of
        # letting the join call itself blow up on a missing frame.
        if other is None:
            raise ValueError(f"Tried joining with None on={on}, how={how}")
        joined = joined.join(other, on=on, how=how)
    return joined


In [4]:
# ------------------ Prepare Source ------------------
def prepare_src(
    df_src: pl.DataFrame,
    key_col: str,
    attr_cols: Optional[list[str]],
    gold_parquet: Path,
    is_composite: bool,
    df_src_initial: Optional[pl.DataFrame],
    extra_cols: Optional[list[str]] = None,
) -> tuple[pl.DataFrame, list[str], list[str]]:
    """
    Reduce the staging frame to the key columns of the dimension being loaded.

    - Simple dims (``attr_cols is None``): keep only the natural key column,
      without nulls, duplicates or the literal "Unknown".
    - Composite dims: replace each attribute by the surrogate key of its
      parent dimension (read from ``gold_dim_<attr>.parquet``).
    - Hierarchical dims (``is_composite`` false): keep the natural key column
      and add the parent surrogate keys.
    - When ``extra_cols`` is given, those columns are joined back from
      ``df_src_initial``.

    Args:
        df_src: Staging frame to prepare.
        key_col: Natural key column of the dimension.
        attr_cols: Parent attribute names, or ``None`` for a simple dimension.
        gold_parquet: Directory searched for ``gold_dim_<attr>.parquet`` files.
        is_composite: True when the dimension has no natural key of its own.
        df_src_initial: Frame the extra columns are taken from; required when
            ``extra_cols`` is non-empty.
        extra_cols: Additional columns to carry along with the keys.

    Returns:
        ``(prepared frame, skey_cols, final_tbl_common_skeys)`` where
        ``skey_cols`` are the key columns selected into the frame and
        ``final_tbl_common_skeys`` are the natural key (non-composite only)
        plus the parent surrogate keys that were found.

    Raises:
        ValueError: If ``extra_cols`` is given without ``df_src_initial``.
    """
    skey_cols: list[str] = []
    final_tbl_common_skeys: list[str] = []

    # --- Case 1: Simple dimension (single key only) ---
    if attr_cols is None:
        df_src = (
            df_src.select(key_col)
                  .drop_nulls()
                  .unique(subset=[key_col])
                  .filter(pl.col(key_col) != "Unknown")
        )
        skey_cols = [key_col]

    # --- Case 2: Composite / hierarchical dimension (join with parent dims) ---
    else:
        # Non-composite dims always keep their own natural key, listed first.
        # This does not depend on the attribute, so it is done once up front;
        # an empty attr_cols list still contributes nothing, as before.
        if not is_composite and attr_cols:
            skey_cols.append(key_col)
            final_tbl_common_skeys.append(key_col)

        for col in attr_cols:
            dim_file = gold_parquet / f"gold_dim_{col}.parquet"
            logger.info(f"Looking for parent dim: {dim_file}")

            if not dim_file.exists():
                # Parent dim missing, so no surrogate key exists yet:
                # fall back to the raw column as natural key.
                skey_cols.append(col)
                continue

            # Use the surrogate key from the parent dimension.
            skey = f"{col}_skey"
            skey_cols.append(skey)
            final_tbl_common_skeys.append(skey)

            dim_df = pl.read_parquet(dim_file)

            # Always keep the skey; keep the attribute only if the parent has it.
            wanted = [skey] + ([col] if col in dim_df.columns else [])
            dim_df = dim_df.select(wanted)

            # Join on the attribute when the source has it, otherwise on the
            # surrogate key, so the source ends up with the foreign key.
            join_key = col if col in df_src.columns else skey
            df_src = df_src.join(dim_df, on=join_key, how="left")

        # Drop repeated names while keeping first-seen order.
        skey_cols = list(dict.fromkeys(skey_cols))
        final_tbl_common_skeys = list(dict.fromkeys(final_tbl_common_skeys))
        df_src = df_src.select(skey_cols)

        # The "Unknown" filter only applies to a string-typed natural key.
        if not is_composite and df_src.schema.get(key_col) == pl.Utf8:
            df_src = df_src.filter(pl.col(key_col) != "Unknown")

        df_src = df_src.unique()

    # --- Re-join extra columns if needed ---
    if extra_cols:
        if df_src_initial is None:
            raise ValueError("extra_cols requires df_src_initial to be provided")

        join_keys = [key_col] if not is_composite else final_tbl_common_skeys
        # Deduplicate the lookup side: the staging frame usually repeats each
        # key many times, and joining it as-is would multiply every dimension
        # row by the number of staging rows sharing its key.
        lookup = df_src_initial.select(join_keys + extra_cols).unique(maintain_order=True)
        df_src = df_src.join(lookup, on=join_keys, how="left")

    return df_src, skey_cols, final_tbl_common_skeys


In [5]:
# ------------------ Assign Surrogate Keys ------------------
def assign_surrogate_keys(
    df: pl.DataFrame,
    key_col: str,
    start_offset: int = 1,
    is_scd2: bool = True,
    batch_id: int = 1
) -> pl.DataFrame:
    """
    Number the rows with a surrogate key and attach load metadata.

    Args:
        df: Rows that need a surrogate key.
        key_col: Natural key name; a trailing "_id" is stripped before
            "_skey" is appended to form the surrogate key column name.
        start_offset: First surrogate key value to hand out.
        is_scd2: When true, also add start_date, end_date and is_current.
        batch_id: Value written to the batch_id column.

    Returns:
        ``df`` with the Int64 surrogate key as a row-index column followed by
        the metadata columns.
    """
    load_ts = datetime.now().replace(microsecond=0)
    skey_name = f"{key_col.removesuffix('_id')}_skey"

    # with_row_index yields an unsigned index column; store it as Int64.
    keyed = df.with_row_index(skey_name, offset=start_offset).with_columns(
        pl.col(skey_name).cast(pl.Int64)
    )

    # Column order matters to callers that concatenate vertically:
    # is_active, [start_date, end_date, is_current], batch_id, load_timestamp.
    metadata = [pl.lit(1).alias("is_active")]
    if is_scd2:
        # NOTE(review): end_date is typed Date while start_date holds a
        # datetime literal - confirm the mixed types are intended.
        metadata += [
            pl.lit(load_ts).alias("start_date"),
            pl.lit(None).cast(pl.Date).alias("end_date"),
            pl.lit(1).alias("is_current"),
        ]
    metadata += [
        pl.lit(batch_id).alias("batch_id"),
        pl.lit(load_ts).alias("load_timestamp"),
    ]

    return keyed.with_columns(metadata)


In [6]:
# ------------------ Add Unknown Row ------------------
def add_unknown_row(
    df_final: pl.DataFrame,
    key_col: str,
    is_scd2: bool = True,
    batch_id: int = 1,
) -> pl.DataFrame:
    """
    Build the single 'Unknown' placeholder row for a dimension.

    The row is returned on its own (the caller appends it) and is shaped to
    match ``df_final`` exactly: same columns, same order, same dtypes.

    Args:
        df_final: Dimension frame the row must be compatible with.
        key_col: Natural key column; it receives the value "Unknown".
        is_scd2: When true, start_date, end_date and is_current are populated.
        batch_id: Value written to the batch_id column.

    Returns:
        A one-row DataFrame whose surrogate key is one above the current
        maximum in ``df_final`` (or 1 when ``df_final`` is empty).
    """
    now = datetime.now().replace(microsecond=0)
    assigned_key_col = key_col[:-3] if key_col.endswith("_id") else key_col
    skey_name = f"{assigned_key_col}_skey"
    next_skey = int(df_final[skey_name].max()) + 1 if not df_final.is_empty() else 1

    base_data = {
        skey_name: [next_skey],
        key_col: ["Unknown"],
        "is_active": [1],
        "batch_id": [batch_id],
        "load_timestamp": [now],
    }

    if is_scd2:
        base_data.update({
            "start_date": [now],
            "end_date": [None],
            "is_current": [1],
        })

    unknown_row = pl.DataFrame(base_data)

    # Columns that exist in df_final but not in the placeholder (for example
    # extra attribute columns) are filled with nulls so the cast can succeed.
    missing = [name for name in df_final.columns if name not in unknown_row.columns]
    if missing:
        unknown_row = unknown_row.with_columns(
            [pl.lit(None).alias(name) for name in missing]
        )

    # cast() changes dtypes but never reorders columns, and a vertical concat
    # needs identical column order, so reorder to df_final's layout first.
    return unknown_row.select(df_final.columns).cast(df_final.schema)


In [7]:
# ------------------ Initial Load ------------------
def handle_initial_load(
    df_src: pl.DataFrame,
    key_col: str,
    attr_cols: Optional[list[str]],
    is_scd2: bool,
    batch_id: int,
) -> pl.DataFrame:
    """
    Build the dimension from scratch when the target table is empty.

    Steps:
    - Assign surrogate keys (starting at 1) and metadata to every source row.
    - When ``attr_cols`` is None, append the 'Unknown' placeholder row.

    Returns:
        The freshly keyed dimension frame.
    """
    keyed = assign_surrogate_keys(
        df_src, key_col, start_offset=1, is_scd2=is_scd2, batch_id=batch_id
    )

    # Dimensions built from attr_cols get no placeholder row.
    if attr_cols is not None:
        return keyed

    placeholder = add_unknown_row(keyed, key_col, is_scd2=is_scd2, batch_id=batch_id)
    return pl.concat([keyed, placeholder], how="vertical")


In [8]:
# ------------------ Incremental Load ------------------
def _align_to_schema(frame: pl.DataFrame, schema: dict) -> pl.DataFrame:
    """Give *frame* exactly the columns of *schema*: same names, order and dtypes."""
    absent = [name for name in schema if name not in frame.columns]
    if absent:
        # Columns the frame never had are filled with typed nulls.
        frame = frame.with_columns(
            [pl.lit(None).cast(schema[name]).alias(name) for name in absent]
        )
    return frame.select(list(schema)).cast(schema, strict=False)


def handle_incremental_load(
    df_src: pl.DataFrame,
    df_tgt: pl.DataFrame,
    key_col: str,
    attr_cols: Optional[list[str]],
    extra_cols: Optional[list[str]],
    is_scd2: bool,
    final_tbl_common_skeys: list[str],
    skey_cols: list[str],
    assigned_key_col: str,
    batch_id: int,
) -> pl.DataFrame:
    """
    Merge a prepared source frame into a non-empty target dimension.

    Flow: expire old rows -> keep common rows -> add new rows -> preserve the
    'Unknown' row, then bring every piece to one schema and concatenate.

    Args:
        df_src: Prepared source rows (key columns plus optional extra columns).
        df_tgt: Existing dimension table.
        key_col: Natural key column; "Unknown" marks the placeholder row.
        attr_cols: ``None`` for simple dims; only then is the 'Unknown' row
            protected from expiry and re-appended.
        extra_cols: Extra attribute columns carried by ``df_src``.
        is_scd2: Whether ``df_tgt`` carries is_current / end_date history.
        final_tbl_common_skeys: Join keys used for common rows when
            ``extra_cols`` is given.
        skey_cols: Join keys used everywhere else.
        assigned_key_col: Prefix of the surrogate key column (``<prefix>_skey``).
        batch_id: Batch id stamped on newly inserted rows.

    Returns:
        The merged dimension, with the target's columns followed by any extra
        columns that the target did not have yet.
    """
    now = datetime.now().replace(microsecond=0)

    # --- 1. Rows only in target -> expire them (SCD2 only) ---
    if is_scd2:
        active_rows = df_tgt.filter(pl.col("is_current") == 1)
        old_expired_rows = df_tgt.filter(pl.col("is_current") == 0)

        # The 'Unknown' placeholder never appears in the source, so it must be
        # excluded here or it would be expired on every run.
        expirable = (
            active_rows.filter(pl.col(key_col) != "Unknown")
            if attr_cols is None
            else active_rows
        )
        newly_expired = expirable.join(df_src, on=skey_cols, how="anti")

        if not newly_expired.is_empty():
            # NOTE(review): only is_current / end_date are updated; is_active
            # stays as it was - confirm that is intended.
            newly_expired = newly_expired.with_columns([
                pl.lit(0).alias("is_current"),
                pl.lit(now).alias("end_date"),
            ])
    else:
        # SCD1: no expiration logic.
        # NOTE(review): target rows missing from the source are not carried
        # over in this branch - confirm that dropping them is intended.
        active_rows = df_tgt
        old_expired_rows = df_tgt.head(0)
        newly_expired = df_tgt.head(0)

    # Extra columns the target does not have yet.
    new_extra_cols = [col for col in (extra_cols or []) if col not in df_tgt.columns]

    # --- 2. Rows in both source and target -> keep the target version ---
    if extra_cols:
        # Keep the target's columns; values for brand-new extra columns can
        # only come from the source side of the join.
        common_rows = (
            active_rows.join(df_src, on=final_tbl_common_skeys, how="inner")
            .select(active_rows.columns + new_extra_cols)
        )
    else:
        common_rows = active_rows.join(df_src, on=skey_cols, how="inner")

    # --- 3. Rows only in source -> new rows with fresh surrogate keys ---
    new_rows = df_src.join(active_rows, on=skey_cols, how="anti")
    if new_rows.is_empty():
        new_rows = df_tgt.head(0)
    else:
        # Continue numbering after the highest key ever issued, including
        # keys that belong to expired rows.
        start_offset = int(df_tgt[f"{assigned_key_col}_skey"].max()) + 1
        new_rows = assign_surrogate_keys(new_rows, key_col, start_offset, is_scd2, batch_id)

    # --- 4. Preserve the 'Unknown' row ---
    unknown_row = (
        df_tgt.filter(pl.col(key_col) == "Unknown") if (attr_cols is None) else df_tgt.head(0)
    )

    # --- 5. Schema alignment (including extra cols) ---
    # Work on a copy so the target's own schema object is never mutated.
    common_schema = dict(df_tgt.schema)
    for col in new_extra_cols:
        common_schema[col] = df_src.schema[col]

    # A plain cast neither adds missing columns nor reorders them, and a
    # vertical concat needs identical layouts, so align each piece fully.
    pieces = [old_expired_rows, newly_expired, common_rows, new_rows, unknown_row]
    pieces = [_align_to_schema(piece, common_schema) for piece in pieces]

    # --- 6. Final concat ---
    return pl.concat(pieces, how="vertical")


In [9]:
# ------------------ Main Merge Function ------------------
def merge_fn(
    df_src: pl.DataFrame,
    df_tgt: Optional[pl.DataFrame],
    key_col: str,
    attr_cols: Optional[list[str]],
    is_scd2: bool,
    extra_cols: Optional[list[str]] = None,
    batch_id: int = 1,
) -> pl.DataFrame:
    """
    Entry point that merges a staging frame into a dimension table.

    Steps:
    - Prepare the source frame (key selection, parent surrogate keys, extras).
    - Dispatch to the initial load when the target is missing or empty,
      otherwise to the incremental load.

    Returns:
        The resulting dimension frame.
    """
    log_step("INSIDE MAIN")
    logger.info(f"attr_cols : {attr_cols}")
    logger.info(f"extra_cols : {extra_cols}")
    logger.info(f"df_src.columns : {df_src.columns}")

    # A dimension is treated as composite when it has attribute columns and
    # the staging frame carries no natural key column of its own.
    is_composite = bool(attr_cols) and key_col not in df_src.columns
    assigned_key_col = key_col.removesuffix("_id")

    # The untouched staging frame doubles as the lookup for extra columns.
    staging = df_src
    prepared, skey_cols, final_tbl_common_skeys = prepare_src(
        staging,
        key_col,
        attr_cols,
        gold_parquet,
        is_composite,
        df_src_initial=staging,
        extra_cols=extra_cols,
    )

    # Case 1: Initial load - nothing in the target yet.
    target_is_empty = df_tgt is None or df_tgt.is_empty()
    if target_is_empty:
        return handle_initial_load(prepared, key_col, attr_cols, is_scd2, batch_id)

    # Case 2: Incremental load.
    return handle_incremental_load(
        prepared,
        df_tgt,
        key_col,
        attr_cols,
        extra_cols,
        is_scd2,
        final_tbl_common_skeys,
        skey_cols,
        assigned_key_col,
        batch_id,
    )
