### Necessary Imports

In [1]:
import os, sys
import logging
from pathlib import Path
from dotenv import load_dotenv

import polars as pl
from sqlalchemy.engine import Engine

import config
from helper_utils import get_batch_id, load_cached_parquet

# Make the parent of the current working directory importable. db_utils is
# imported immediately afterwards, so it presumably lives there (verify).
top_level = Path().resolve().parent
sys.path.append(str(top_level))
from db_utils import engine


# Configure the root logger at INFO level and create this notebook's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load variables from a .env file; override=True lets values from .env replace
# variables that are already set in the process environment.
load_dotenv(override=True)

True

In [2]:
%run ./construct_merge.ipynb

### RUNNER FUNCTIONS
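- Provides the runner functions `build_source` -> `run_merge` -> `build_and_merge_all`: build each source DataFrame, merge it into its gold-layer target, and drive both steps for every configured table.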

In [3]:
# Resolve the data directories from the environment. Indexing os.environ
# (instead of os.getenv) makes a missing variable fail with a KeyError that
# names it, rather than the opaque TypeError raised by Path(None).
POLARS_DWH = Path(os.environ["POLARS_DWH"])
PARQUET_FILES_DIR = Path(os.environ["PARQUET_FILES_DIR"])

# Local staging-layer Parquet directory (created if it does not exist yet).
staging_parquet = PARQUET_FILES_DIR / 'staging_layer'
staging_parquet.mkdir(parents=True, exist_ok=True)

# Local gold-layer Parquet directory (created if it does not exist yet).
gold_parquet = PARQUET_FILES_DIR / 'gold_layer'
gold_parquet.mkdir(parents=True, exist_ok=True)

In [4]:
# ------------------ Construct Source Dataframe ------------------
def build_source(tbl: str, staging_parquet: Path, gold_parquet: Path) -> pl.DataFrame:
    """
    Build the source DataFrame for one target table.

    - Load the staging-layer table and every gold-layer table listed under "joins" in the config.
    - Drop staging rows whose key column is null or empty (string keys) / null or non-positive (other keys).
    - Join the gold tables onto the staging table, apply the optional column-level transform,
      and keep only the configured "select_cols".

    Args:
        tbl: Key into ``config.TABLE_CONFIG``; the entry supplies "joins", "staging_file",
            "key_col", "select_cols" and optionally "transform_fn".
        staging_parquet: Directory holding the staging-layer Parquet files.
        gold_parquet: Directory holding the gold-layer Parquet files.

    Returns:
        The joined (and optionally transformed) DataFrame restricted to "select_cols".

    Raises:
        ValueError: If the staging file or any join source could not be loaded.
    """
    props = config.TABLE_CONFIG[tbl]
    key_col = props["key_col"]
    staging_file = props["staging_file"]

    # --- Prepare list of (DataFrame, on, how) tuples for every joining table ---
    gold_parquet_joins = []
    for parquet_file, cols, on, how in props["joins"]:
        df_gold = load_cached_parquet(gold_parquet, parquet_file, cols)
        if df_gold is None:
            raise ValueError(f"Join source {parquet_file} returned None! Check path: {gold_parquet}/{parquet_file}")
        gold_parquet_joins.append((df_gold, on, how))

    # --- Get optional transform ---
    transform_fn_name = props.get("transform_fn")
    transform_fn = config.transform_fn_map.get(transform_fn_name)
    if transform_fn_name and transform_fn is None:
        # A configured-but-unknown transform would otherwise be skipped without a trace.
        logger.warning(
            "transform_fn %r configured for %s is not in config.transform_fn_map; no transform applied",
            transform_fn_name, tbl,
        )

    # --- Load staging ---
    df_staging = load_cached_parquet(staging_parquet, staging_file)
    if df_staging is None:
        # Same guard as for the join sources: fail with a clear message instead of
        # an AttributeError on the schema lookup below.
        raise ValueError(f"Staging source {staging_file} returned None! Check path: {staging_parquet}/{staging_file}")

    # --- Key column filtering ---
    key_expr = pl.col(key_col)
    if df_staging.schema[key_col] == pl.Utf8:
        # String keys: drop null/empty keys and keep one row per key.
        df_staging = df_staging.filter(key_expr.is_not_null() & (key_expr != "")).unique(key_col)
    else:
        # Other keys: drop null and non-positive keys.
        # NOTE(review): no de-duplication happens on this branch, unlike the string
        # branch above - confirm that this difference is intended.
        df_staging = df_staging.filter(key_expr.is_not_null() & (key_expr > 0))

    # --- Apply joins ---
    df_final = apply_joins(df_staging, gold_parquet_joins)

    # --- Apply transform ---
    if transform_fn:
        df_final = transform_fn(df_final)

    # --- Final select (no sorting is applied here) ---
    return df_final.select(props["select_cols"])


In [5]:
# ------------------ Prepare Merge Execution ------------------
def run_merge(
    df_src: pl.DataFrame,
    tgt_tbl: str,
    engine: Engine,
    schema_name: str
) -> pl.DataFrame:
    """
    Merge a source DataFrame into one gold-layer target table and persist the result.

    Steps, in execution order:
    - Read the table's flags and column settings from the config.
    - Load the current target from the local gold layer and obtain a batch id from it.
    - Prepare the source DataFrame via 'prepare_src'.
    - Run the initial load if the target is missing or empty, else the incremental load.
    - Write the result to the local gold-layer Parquet file and to '<schema_name>.<tgt_tbl>'
      in the database (an existing table is replaced).

    Returns the merged DataFrame.
    """
    # --- Table flags from config (defaults: not SCD2, is a dimension) ---
    tbl_props = config.TABLE_PROPS.get(tgt_tbl, {})
    is_scd2 = tbl_props.get("is_scd2", False)
    is_dim = tbl_props.get("is_dim", True)

    # --- Current target state; dimension vs fact decides the file name ---
    if is_dim:
        parquet_name = f"gold_dim_{tgt_tbl}.parquet"
    else:
        parquet_name = f"gold_fact_{tgt_tbl}.parquet"
    current_tgt = load_cached_parquet(gold_parquet, parquet_name)

    batch_id = get_batch_id(current_tgt)

    # --- Column settings; the key column defaults to the table name itself ---
    attr_cols = config.parent_map.get(tgt_tbl)
    extra_cols = config.extra_col_map.get(tgt_tbl)
    key_col = config.key_col_map.get(tgt_tbl, tgt_tbl)

    # Key column name without a trailing "_id".
    assigned_key_col = key_col.removesuffix("_id")
    # Composite when attribute columns are configured and the source has no key column.
    is_composite = bool(attr_cols) and key_col not in df_src.columns

    # --- Logs ---
    log_step("INSIDE MAIN")
    logger.info(f"attr_cols : {attr_cols}")
    logger.info(f"extra_cols : {extra_cols}")
    logger.info(f"df_src.columns : {df_src.columns}")

    # --- Prepare df_src ---
    df_src, skey_cols, final_tbl_common_skeys = prepare_src(
        df_src, key_col, attr_cols, gold_parquet,
        is_composite,
        df_src_initial=df_src,
        extra_cols=extra_cols
    )

    # --- Choose load strategy ---
    target_has_rows = current_tgt is not None and not current_tgt.is_empty()
    if target_has_rows:
        # Incremental load: merge into the existing target.
        merged = handle_incremental_load(
            df_src, current_tgt, key_col, attr_cols, extra_cols, is_scd2,
            final_tbl_common_skeys, skey_cols, assigned_key_col, batch_id
        )
    else:
        # Initial load: nothing to merge against yet.
        merged = handle_initial_load(df_src, key_col, attr_cols, is_scd2, batch_id)
    logger.info(f"df_final.columns : {merged.columns}")

    # --- Persist: local gold layer first, then the database ---
    merged.write_parquet(gold_parquet / parquet_name)

    table_name = f"{schema_name}.{tgt_tbl}"
    merged.write_database(
        table_name=table_name,
        connection=engine,
        if_table_exists="replace",
    )
    logger.info(f"Saved {tgt_tbl} into DB schema {schema_name}")

    return merged


In [6]:
# ------------------ Build and Merge All ------------------
def build_and_merge_all(table_groups: dict, schema_name: str = "gold") -> dict[str, pl.DataFrame]:
    """
    Build and merge every table listed in 'table_groups'.

    - "initial": mapping of staging source table -> target tables. The staging file
      'staging_<src_table>.parquet' is read once and merged into each of its targets.
    - "final": iterable of target tables whose source is constructed by 'build_source'
      before merging.

    Args:
        table_groups: Config dict with the optional keys "initial" and "final".
        schema_name: Database schema passed to 'run_merge'; defaults to 'gold'.

    Returns:
        Mapping of target table name -> merged DataFrame returned by 'run_merge'.
    """
    results: dict[str, pl.DataFrame] = {}

    # --- INITIAL group ---
    for src_table, tgt_tbls in table_groups.get("initial", {}).items():
        src_file = staging_parquet / f"staging_{src_table}.parquet"
        df_src = pl.read_parquet(src_file)

        for tgt_tbl in tgt_tbls:
            logger.info(f"WORKING FOR INITIAL TABLE: {tgt_tbl}")
            results[tgt_tbl] = run_merge(df_src, tgt_tbl, engine, schema_name=schema_name)

    # --- Final group ---
    for tgt_tbl in table_groups.get("final", []):
        logger.info(f"WORKING FOR FINAL TABLE: {tgt_tbl}")

        # Each source frame is used once, so it is not kept after its merge.
        df_src = build_source(tgt_tbl, staging_parquet, gold_parquet)
        results[tgt_tbl] = run_merge(df_src, tgt_tbl, engine, schema_name=schema_name)

    return results
