In [11]:
import os
import urllib
import urllib.parse
from pathlib import Path

import polars as pl
from dotenv import load_dotenv
from sqlalchemy import create_engine

import config

# Load variables from the local .env file into os.environ; override=True makes values
# from .env replace any variables of the same name already set in the process environment.
load_dotenv(override=True) 


True

In [12]:
# Execute the helper notebooks so their definitions become available in this kernel.
# NOTE(review): names used below but not defined in this notebook (load_cached_parquet,
# get_batch_id, apply_joins, merge_fn) are presumably provided by these -- confirm.
%run ./helper_utils.ipynb
%run ./merge.ipynb

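#### Runner functions
- Provides the pipeline entry points: `build_source` (assemble a target table's source frame) → `run_merge` (merge it into the gold layer and persist it), orchestrated across all table groups by `build_and_merge_all`.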

In [13]:
def _require_env_path(var_name: str) -> Path:
    """Return environment variable ``var_name`` as a Path.

    Raises:
        EnvironmentError: if the variable is unset or empty. The bare
            ``Path(os.getenv(...))`` form would instead fail with an opaque
            ``TypeError`` (``Path(None)``) or silently resolve ``""`` to ``.``.
    """
    value = os.getenv(var_name)
    if not value:
        raise EnvironmentError(f"Environment variable {var_name} is not set (expected in .env)")
    return Path(value)


POLARS_DWH = _require_env_path("POLARS_DWH")
PARQUET_FILES_DIR = _require_env_path("PARQUET_FILES_DIR")

# Layer output directories are created up front so later parquet writes never hit a missing dir.
staging_parquet = PARQUET_FILES_DIR / "staging_layer"
staging_parquet.mkdir(parents=True, exist_ok=True)

gold_parquet = PARQUET_FILES_DIR / "gold_layer"
gold_parquet.mkdir(parents=True, exist_ok=True)


In [14]:
## --------------------------------------------------- ##
### Setup the db-connection
# Credentials are read from the environment (populated from .env by load_dotenv above)
# rather than hardcoded in the notebook. Server and database keep their previous values
# as defaults; DB_USER and DB_PASSWORD are required.
# NOTE(review): a password was previously committed in this cell; it remains in version
# history and should be rotated.
def _odbc_braced(value: str) -> str:
    """Wrap an ODBC attribute value in braces so ';', '=', '#' etc. are taken literally."""
    # Per ODBC connection-string syntax, a literal '}' inside a braced value is doubled.
    return "{" + value.replace("}", "}}") + "}"


_missing_db_vars = [name for name in ("DB_USER", "DB_PASSWORD") if not os.getenv(name)]
if _missing_db_vars:
    raise EnvironmentError(
        f"Missing database settings in environment/.env: {', '.join(_missing_db_vars)}"
    )

params = urllib.parse.quote_plus(
    "DRIVER={ODBC Driver 17 for SQL Server};"
    f"SERVER={os.getenv('DB_SERVER', 'associatetraining.database.windows.net,1433')};"
    f"DATABASE={os.getenv('DB_NAME', 'associatetraining')};"
    f"UID={os.environ['DB_USER']};"
    f"PWD={_odbc_braced(os.environ['DB_PASSWORD'])};"
)

engine = create_engine(f"mssql+pyodbc:///?odbc_connect={params}")


In [15]:
## --------------------------------------------------- ##
def build_source(tbl: str, staging_parquet: Path, gold_parquet: Path) -> pl.DataFrame:
    """Build the cleaned, joined and (optionally) transformed source frame for ``tbl``.

    Driven by ``config.TABLE_CONFIG[tbl]``, which must provide ``staging_file``,
    ``key_col`` and ``select_cols``. It may also provide:

    - ``joins``: a list of ``(parquet_file, cols, on, how)`` tuples loaded from the
      gold layer (default: no joins).
    - ``transform_fn``: a key into ``config.transform_fn_map``.

    Rows whose key is null or ``""`` (string keys) or null or ``<= 0`` (other
    dtypes) are dropped. String keys are additionally de-duplicated, keeping the
    first occurrence.

    Args:
        tbl: Target table name; key into ``config.TABLE_CONFIG``.
        staging_parquet: Directory containing the staging-layer parquet files.
        gold_parquet: Directory containing the gold-layer parquet files used as join sources.

    Returns:
        The source frame projected to ``select_cols``.

    Raises:
        KeyError: if ``tbl`` or a required config entry is missing.
        ValueError: if a join source or the staging file loads as ``None``, or if
            ``transform_fn`` names a function absent from ``config.transform_fn_map``.
    """
    props = config.TABLE_CONFIG[tbl]
    key_col = props["key_col"]

    # --- Load every gold join source up front so a missing file fails fast ---
    gold_parquet_joins = []
    for parquet_file, cols, on, how in props.get("joins", []):
        df_gold = load_cached_parquet(gold_parquet, parquet_file, cols)
        if df_gold is None:
            raise ValueError(f"Join source {parquet_file} returned None! Check path: {gold_parquet}/{parquet_file}")
        gold_parquet_joins.append((df_gold, on, how))

    # --- Resolve the optional transform; a configured but unregistered name is a config error ---
    transform_fn_name = props.get("transform_fn")
    transform_fn = None
    if transform_fn_name is not None:
        transform_fn = config.transform_fn_map.get(transform_fn_name)
        if transform_fn is None:
            raise ValueError(
                f"Transform '{transform_fn_name}' for table {tbl} not found in config.transform_fn_map"
            )

    # --- Load staging (checked like the join sources so failures point at the file) ---
    staging_file = props["staging_file"]
    df_staging = load_cached_parquet(staging_parquet, staging_file)
    if df_staging is None:
        raise ValueError(f"Staging source {staging_file} returned None! Check path: {staging_parquet}/{staging_file}")

    # --- Key column filtering ---
    key = pl.col(key_col)
    if df_staging.schema[key_col] == pl.Utf8:
        # keep="first" + maintain_order=True makes the surviving duplicate deterministic
        # across runs (the default keep="any" does not guarantee which row survives).
        df_staging = (
            df_staging
            .filter(key.is_not_null() & (key != ""))
            .unique(key_col, keep="first", maintain_order=True)
        )
    else:
        # NOTE(review): numeric keys are not de-duplicated, unlike string keys -- confirm intended.
        df_staging = df_staging.filter(key.is_not_null() & (key > 0))

    # --- Apply joins ---
    df_final = apply_joins(df_staging, gold_parquet_joins)

    # --- Apply transform ---
    if transform_fn is not None:
        df_final = transform_fn(df_final)

    # --- Final column projection (no sort is applied) ---
    return df_final.select(props["select_cols"])


In [16]:
## --------------------------------------------------- ##
def run_merge(df_src: pl.DataFrame, tgt_tbl: str, engine, schema_name: str) -> pl.DataFrame:
    """Merge ``df_src`` into the gold target ``tgt_tbl``, then persist to parquet and the DB.

    Per-table behaviour is read from ``config``:

    - ``TABLE_PROPS[tgt_tbl]``: ``is_scd2`` (default False) and ``is_dim`` (default
      True, which selects the ``gold_dim_`` vs ``gold_fact_`` file prefix).
    - ``parent_map`` / ``extra_col_map``: attribute and extra columns (default None).
    - ``key_col_map``: key column name, falling back to ``tgt_tbl`` itself.

    The existing target is loaded from the module-level ``gold_parquet`` directory.
    The merged frame overwrites that parquet file and replaces the DB table
    ``<schema_name>.<tgt_tbl>`` (``if_table_exists="replace"``).

    NOTE(review): parquet is written before the DB; if the DB write fails the two
    copies diverge. If ``load_cached_parquet`` memoises by path, later reads of this
    file in the same session may be stale -- confirm.

    Returns:
        The merged frame produced by ``merge_fn``.
    """
    table_props = config.TABLE_PROPS.get(tgt_tbl, {})
    scd2_enabled = table_props.get("is_scd2", False)

    if table_props.get("is_dim", True):
        parquet_name = f"gold_dim_{tgt_tbl}.parquet"
    else:
        parquet_name = f"gold_fact_{tgt_tbl}.parquet"

    existing_tgt = load_cached_parquet(gold_parquet, parquet_name)
    run_batch_id = get_batch_id(existing_tgt)

    df_final = merge_fn(
        df_src=df_src,
        df_tgt=existing_tgt,
        key_col=config.key_col_map.get(tgt_tbl, tgt_tbl),
        attr_cols=config.parent_map.get(tgt_tbl, None),
        extra_cols=config.extra_col_map.get(tgt_tbl, None),
        is_scd2=scd2_enabled,
        batch_id=run_batch_id,
    )
    print(f"df_final.columns : {df_final.columns}")

    # Local gold layer first, then the SQL Server copy under `schema_name`.
    df_final.write_parquet(gold_parquet / parquet_name)
    df_final.write_database(
        table_name=f"{schema_name}.{tgt_tbl}",
        connection=engine,
        if_table_exists="replace",
    )
    print(f"Saved {tgt_tbl} into DB schema {schema_name}")

    return df_final


In [17]:
## --------------------------------------------------- ##
def build_and_merge_all(table_groups: dict, schema_name: str = "gold") -> dict[str, pl.DataFrame]:
    """Run every configured merge and return the merged frames keyed by target table.

    Args:
        table_groups: Mapping with optional keys:

            - ``"initial"``: ``{src_table: [tgt_tbl, ...]}``. The raw frame read from
              ``staging_<src_table>.parquet`` is passed straight to ``run_merge``
              for each listed target.
            - ``"final"``: ``[tgt_tbl, ...]``. Each target's source is assembled by
              ``build_source`` (staging + gold joins + optional transform).

            All initial targets are processed before any final target. Final
            sources join against gold files, presumably including those just
            written by the initial pass.
        schema_name: DB schema the merged tables are written to (default ``"gold"``).

    Returns:
        ``{tgt_tbl: merged DataFrame}`` for every target processed.
    """
    separator = "\n==============================================================================\n"
    results: dict[str, pl.DataFrame] = {}

    # --- Initial group: one staging file read once, feeding each of its targets ---
    for src_table, tgt_tbls in table_groups.get("initial", {}).items():
        df_src = pl.read_parquet(staging_parquet / f"staging_{src_table}.parquet")

        for tgt_tbl in tgt_tbls:
            print(separator)
            print(f"WORKING FOR INITIAL TABLE: {tgt_tbl}")
            results[tgt_tbl] = run_merge(df_src, tgt_tbl, engine, schema_name=schema_name)

    # --- Final group: source built from staging + gold joins ---
    for tgt_tbl in table_groups.get("final", []):
        print(separator)
        print(f"WORKING FOR FINAL TABLE: {tgt_tbl}")

        df_src = build_source(tgt_tbl, staging_parquet, gold_parquet)
        results[tgt_tbl] = run_merge(df_src, tgt_tbl, engine, schema_name=schema_name)

    return results
