In [0]:
%sql
-- Make `corerep` the current catalog so the unqualified schema name below is created inside it.
use catalog corerep;
-- Create the `transform_layer` schema only when it does not already exist, so re-running this cell is safe.
create schema if not exists transform_layer;

In [0]:
from typing import Dict, List, Tuple, Optional
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql.column import Column
from pyspark.sql.window import Window


def _require_columns(df: DataFrame, cols: List[str], label: str) -> None:
    """Raise ``ValueError`` for the first entry of ``cols`` that is not a column of ``df``.

    ``label`` is the role of the column (e.g. "DQ column") and is used verbatim
    as the prefix of the error message.
    """
    for c in cols:
        if c not in df.columns:
            raise ValueError(f"{label} '{c}' not found. Available: {df.columns}")


def transform_to_silver_gl(
    df_in: DataFrame,
    source_system: str,
    *,
    # 1) Select + cast + rename (schema mapping)
    select_exprs: List[Column],

    # 2) DQ filters
    not_null_cols: List[str],

    # 3) Group and aggregate
    group_by_cols: List[str],
    sum_cols: List[str],  # columns to sum
    net_movement_expr: Column,  # expression that creates Net_Movement

    # 4) De-duplication (deterministic)
    dedup_key: List[str],
    dedup_order: List[Tuple[str, str]],  # e.g. [("ParentSystemId","desc")]

    # 5) Closing balance window
    closing_partition_cols: List[str],
    closing_order_col: str,  # e.g. "PeriodKey"

    # Optional: add more calculated cols after aggregation
    post_calc_cols: Optional[Dict[str, Column]] = None,
) -> DataFrame:
    """Run the select -> filter -> aggregate -> de-duplicate -> closing-balance pipeline.

    Steps, in order:
      0. Add ``_ingest_ts`` (current timestamp) and ``_source_system`` columns.
      1. Project ``df_in`` through ``select_exprs``.
      2. Keep only rows where every column in ``not_null_cols`` is non-null.
      3. Group by ``group_by_cols`` and sum each column in ``sum_cols`` (the sum
         keeps the source column's name), then add ``Net_Movement`` and any
         ``post_calc_cols``.
      4. Keep the first row per ``dedup_key`` according to ``dedup_order``.
      5. Add ``Closing_Balance`` as a running sum of ``Net_Movement``.

    Args:
        df_in: Input DataFrame.
        source_system: Literal written to the ``_source_system`` column.
        select_exprs: Column expressions for the projection in step 1.
        not_null_cols: Post-projection columns that must be non-null. May be
            empty, in which case no rows are filtered out.
        group_by_cols: Grouping columns for the aggregation.
        sum_cols: Columns to sum per group; must not be empty.
        net_movement_expr: Expression evaluated after aggregation to produce
            ``Net_Movement``.
        dedup_key: Columns identifying a duplicate group.
        dedup_order: ``(column, "asc" | "desc")`` pairs that rank rows inside a
            duplicate group; must not be empty.
        closing_partition_cols: Partition columns for the closing-balance window.
        closing_order_col: Column the closing-balance window is ordered by
            (descending).
        post_calc_cols: Optional ``{name: expression}`` columns added after
            ``Net_Movement``, in dict order.

    Returns:
        DataFrame containing ``group_by_cols``, the summed ``sum_cols``,
        ``Net_Movement``, any ``post_calc_cols`` and ``Closing_Balance``.

    Raises:
        ValueError: If a referenced column is missing at the step that needs it,
            if a ``dedup_order`` direction is not "asc"/"desc", or if
            ``sum_cols`` / ``dedup_order`` is empty.
    """

    # Step 0: audit
    # These columns are only available to `select_exprs`; they reach the output
    # only if step 1 selects them and step 3 groups by (or sums) them.
    df = (df_in
          .withColumn("_ingest_ts", F.current_timestamp())
          .withColumn("_source_system", F.lit(source_system)))

    # Step 1: select/cast/rename
    df = df.select(*select_exprs)

    # Step 2: DQ filters
    dq_cond = None
    for c in not_null_cols:
        _require_columns(df, [c], "DQ column")
        not_null = F.col(c).isNotNull()
        dq_cond = not_null if dq_cond is None else (dq_cond & not_null)
    # With no DQ columns there is nothing to filter; filter(None) would raise.
    if dq_cond is not None:
        df = df.filter(dq_cond)

    # Step 3: aggregate
    _require_columns(df, group_by_cols, "group_by column")
    _require_columns(df, sum_cols, "sum column")
    if not sum_cols:
        # GroupedData.agg() rejects an empty expression list with a bare
        # assertion; fail with an actionable message instead.
        raise ValueError("sum_cols must contain at least one column to aggregate.")

    agg_exprs = [F.sum(c).alias(c) for c in sum_cols]
    df = df.groupBy(*group_by_cols).agg(*agg_exprs)

    # Add Net_Movement
    df = df.withColumn("Net_Movement", net_movement_expr)

    # Optional extra derived cols after aggregation (e.g., PeriodKey)
    if post_calc_cols:
        for name, expr in post_calc_cols.items():
            df = df.withColumn(name, expr)

    # Step 4: dedupe (keep the first row per key under dedup_order)
    _require_columns(df, dedup_key, "dedup key")
    if not dedup_order:
        # row_number() needs an ordered window; without an ordering Spark fails
        # later with an analysis error that does not point at this argument.
        raise ValueError("dedup_order must contain at least one (column, direction) pair.")

    order_exprs = []
    for col_name, direction in dedup_order:
        _require_columns(df, [col_name], "dedup order col")
        normalized = direction.lower()
        if normalized == "desc":
            order_exprs.append(F.col(col_name).desc())
        elif normalized == "asc":
            order_exprs.append(F.col(col_name).asc())
        else:
            raise ValueError(f"Invalid direction '{direction}' for '{col_name}'. Use 'asc' or 'desc'.")

    # NOTE: row_number() breaks ties arbitrarily, so the result is only
    # deterministic when dedup_order is unique within each dedup_key group.
    w = Window.partitionBy(*dedup_key).orderBy(*order_exprs)
    df = (df.withColumn("_rn", F.row_number().over(w))
            .filter(F.col("_rn") == 1)
            .drop("_rn"))

    # Step 5: Closing balance (running sum by closing_order_col desc)
    _require_columns(df, closing_partition_cols, "closing partition col")
    _require_columns(df, [closing_order_col], "closing order col")

    # NOTE(review): with descending order the running sum starts at the highest
    # closing_order_col value, so the latest row carries only its own movement
    # and the earliest row carries the partition total. Confirm this is the
    # intended direction for a closing balance.
    w2 = (Window.partitionBy(*closing_partition_cols)
                .orderBy(F.col(closing_order_col).desc())
                .rowsBetween(Window.unboundedPreceding, Window.currentRow))

    df = df.withColumn("Closing_Balance", F.sum("Net_Movement").over(w2))

    return df


In [0]:
# Raw GL balances that feed the silver transformation.
df_transform_gl_account = spark.table("corerep.raw_data.raw_gl_balances")

# (source column, target type) in output order; None keeps the type as read
# from the raw table.
gl_column_casts = [
    ("ParentSystemId", "bigint"),
    ("GLBalanceKey", "string"),
    ("LedgerId", None),
    ("TransactionCurrency", None),
    ("LedgerCurrency", "string"),
    ("PeriodYear", None),
    ("PeriodNum", None),
    ("TransCurrPeriodNetDr", None),
    ("TransCurrPeriodNetCr", None),
]

# Every expression keeps the source column's name.
select_exprs = [
    (F.col(col_name) if target_type is None else F.col(col_name).cast(target_type)).alias(col_name)
    for col_name, target_type in gl_column_casts
]

# Dimensions shared by the aggregation and the closing-balance window; the
# aggregation additionally groups by PeriodNum.
gl_year_level_cols = [
    "ParentSystemId", "GLBalanceKey", "LedgerId",
    "TransactionCurrency", "LedgerCurrency", "PeriodYear",
]
debit_col, credit_col = "TransCurrPeriodNetDr", "TransCurrPeriodNetCr"

# NOTE(review): de-duplication keeps a single row per GLBalanceKey, and
# GLBalanceKey is also part of the closing-balance partition, so each closing
# window holds one row and Closing_Balance equals Net_Movement. Confirm that
# the dedup key / partition columns are the intended ones.
df_general_ledger = transform_to_silver_gl(
    df_transform_gl_account,
    "raw_gl_balances",
    # Step 1: schema mapping
    select_exprs=select_exprs,
    # Step 2: rows missing either identifier are dropped
    not_null_cols=["ParentSystemId", "GLBalanceKey"],
    # Step 3: sum debits and credits per period, net movement = debit - credit
    group_by_cols=[*gl_year_level_cols, "PeriodNum"],
    sum_cols=[debit_col, credit_col],
    net_movement_expr=F.col(debit_col) - F.col(credit_col),
    # Step 4: keep the row with the highest ParentSystemId per GLBalanceKey
    dedup_key=["GLBalanceKey"],
    dedup_order=[("ParentSystemId", "desc")],
    # Step 5: running sum ordered by PeriodKey
    closing_partition_cols=list(gl_year_level_cols),
    closing_order_col="PeriodKey",
    # PeriodKey is derived after aggregation as PeriodYear * 100 + PeriodNum.
    post_calc_cols={"PeriodKey": F.col("PeriodYear") * 100 + F.col("PeriodNum")},
)

df_general_ledger.display()
