# --------------------------------------------------------------------------------
# DQ MONITOR — Core bootstrap, config, and thresholds
# --------------------------------------------------------------------------------
"""
DQ Monitoring — quick start

What this script does (at a glance)
-----------------------------------
- Loads Spark once per run and reuses cached tables during checks.
- Writes both a consolidated monitoring report and a per-metric history table.
- Uses recent metric history as fallback reference for PSI/CSI and drift baselines.
- Captures internal exceptions into a dedicated dataset for auditability.

Key outputs (Dataiku datasets)
------------------------------
- DQ_MONITORING_REPORT   : One row per metric/check evaluated in this run.
- DQ_METRIC_HISTORY      : Long-term, per-metric event store (reference source for drift).
- DQ_INTERNAL_ERRORS     : Fail-safe sink for internal exceptions (not data issues).

How to run
----------
1) Ensure project flow variables include DKU_DST_date (YYYY-MM-DD). If absent, "today" is used.
2) Provide region via project variables: region or country_region (defaults to 'hk').
3) Run in Dataiku as a recipe or from a notebook with access to the same variables.

Important knobs (tune as needed)
--------------------------------
- DRIFT_BINS                     : Histogram bins for numeric drift (PSI/Wasserstein bucketization).
- DRIFT_SAMPLE_SIZE              : Max sample size per partition for drift checks.
- MAX_REF_AGE_DAYS               : Max age for reference partitions.
- MAX_CATEGORICAL_CARDINALITY    : Guardrail for high-cardinality categorical features.
- dq_static_thresholds           : Single source of truth for alert thresholds.

Reliability notes
-----------------
- Empty-DF-safe operations (checks won’t explode on empty inputs).
- No duplicate “raw metrics” — sanitization is applied before persisting.
- Verbose logging with run_id and run_ts to tie all artifacts to a single execution.
"""

In [0]:
from __future__ import annotations

"""
DQ monitoring script — metrics_history is the canonical record; per-metric outputs are derived from it.
"""

# --- Standard library ---
import ast
import copy
import json
import logging
import math
import uuid
from datetime import date, datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple, Union

# --- Dataiku and third-party ---
import dataiku
from dataiku import spark as dkuspark

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.ml.feature import Bucketizer

# Pull in commonly used types so bare names (StructType, StringType, …) work
from pyspark.sql.types import (
 StructType, StructField, StringType, DoubleType, TimestampType, NumericType, DateType
)

# Handy function aliases where code uses F_sum / F_countDistinct, etc.
F_sum = F.sum
F_countDistinct = F.countDistinct
F_min = F.min
F_max = F.max
F_when = F.when

# --- Internal ---
from scmac.spark import start_spark_session
from stables.data_sources.utils import get_ods_partition

import json
import math
import uuid
import logging
from datetime import datetime, date, timedelta
from typing import Any, Dict, List, Optional, Tuple, Union

import dataiku
from dataiku import spark as dkuspark
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
from pyspark.sql.utils import AnalysisException

In [0]:
# -----------------------
# Constants / config
# -----------------------

# Region used when none is explicitly provided.
DEFAULT_REGION = "hk"

# Guard against divide-by-zero / log(0) in ratio and log computations.
EPSILON = 1e-9

# Drift-check sizing.
DRIFT_BINS = 10             # bins for numeric drift (PSI/Wasserstein bucketization)
DRIFT_SAMPLE_SIZE = 15_000  # max rows sampled per partition for drift checks
MAX_LOOKBACK_SAMPLES = 20   # max historical partitions used for fallback PSI/CSI
MAX_REF_AGE_DAYS = 10       # reference partition must be within N days of current

# Upper cap on categories considered for CSI/JS diagnostics.
MAX_CATEGORICAL_CARDINALITY = 5_000

In [0]:
# -----------------------
# Logger & Spark session
# -----------------------
# Named logger for the monitor. The stream handler is attached only when the
# logger has none yet, so re-executing this cell does not duplicate log lines.
logger = logging.getLogger("dq-monitor")
logger.setLevel(logging.INFO)
if not logger.handlers:
    h = logging.StreamHandler()
    logger.addHandler(h)
    h.setFormatter(
        logging.Formatter("%(asctime)s | %(levelname)s | %(message)s")
    )

# Spark session: created once at module level via the project helper
# start_spark_session, which returns a (spark, sql_context) pair.
spark: SparkSession
if not start_spark_session:
    raise ImportError("start_spark_session not available")

spark, sql_context = start_spark_session(
    "Mules_DQ_MONITOR",
    spark_config={
        # Memory sizing
        "spark.executor.memory": "16g",
        "spark.driver.memory": "16g",
        # Adaptive query execution
        "spark.sql.adaptive.enabled": "true",
        "spark.sql.adaptive.coalescePartitions.enabled": "true",
        # NOTE(review): this key is deprecated since Spark 3.0 in favour of
        # spark.sql.adaptive.advisoryPartitionSizeInBytes — confirm the
        # cluster's Spark version before renaming.
        "spark.sql.adaptive.shuffle.targetPostShuffleInputSize": "128m",
        "spark.sql.broadcastTimeout": "1200",
        # Dynamic allocation bounds
        "spark.dynamicAllocation.enabled": "true",
        "spark.dynamicAllocation.initialExecutors": "2",
        "spark.dynamicAllocation.minExecutors": "1",
        "spark.dynamicAllocation.maxExecutors": "10",
        # ANSI SQL mode explicitly disabled
        "spark.sql.ansi.enabled": "false",
    },
)

In [0]:
# --- runtime / env ---
# Run identity: generated exactly once so every artifact written by this
# execution carries the same run_id / run_ts. RUN_TS is UTC with a trailing "Z".
RUN_ID = str(uuid.uuid4())
RUN_TS = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")

# Partition date under test: the DKU_DST_date flow variable (YYYY-MM-DD) when
# available, otherwise today's date. The broad except is deliberate: flow
# variables may be absent (e.g. notebook runs) or the value may be malformed,
# and either case falls back to today — but the real cause is logged.
try:
    CURR_DATE_STR = dataiku.dku_flow_variables.get("DKU_DST_date", datetime.now().strftime("%Y-%m-%d"))
    CURR_DATE = datetime.strptime(CURR_DATE_STR, "%Y-%m-%d").date()
except Exception as e:
    logger.warning(f"Flow vars missing or DKU_DST_date invalid ({e}); using today's date.")
    # Derive both values from a single clock read so they cannot disagree.
    CURR_DATE = datetime.now().date()
    CURR_DATE_STR = CURR_DATE.strftime("%Y-%m-%d")

# Project custom variables; an empty dict means every setting below uses its default.
try:
    _custom_vars = dataiku.get_custom_variables() or {}
except Exception as e:
    logger.warning(f"Could not retrieve custom variables: {e}. Using default values.")
    _custom_vars = {}

# Region is resolved from project custom variables only: country_region first,
# then region, then DEFAULT_REGION. It is normalised to trimmed lower case.
REGION = (_custom_vars.get('country_region') or _custom_vars.get('region') or DEFAULT_REGION).strip().lower()
PROTEGRITY_UDF = _custom_vars.get('PROTEGRITY_UDF', 'protegrity.ptyProtectStr')
PROTEGRITY_POLICY = _custom_vars.get('PROTEGRITY_POLICY', 'TE_A_N_L0R0_S23_Y_AST')


# Column projection per metric type; "default" is the full column list.
PER_METRIC_COLUMNS_MAP: Dict[str, List[str]] = {
    "default": [
        "run_id",
        "run_ts",
        "partition_date",
        "metric_type",
        "source_system",
        "table_name",
        "column_name",
        "metric_value",
        "metric_value_num",
        "threshold",
        "reference_value",
        "status",
        "alert_flag",
        "country",
    ],
    # Table-level metric: no column_name / threshold / alert_flag.
    "row_count": [
        "run_id",
        "run_ts",
        "partition_date",
        "metric_type",
        "source_system",
        "table_name",
        "metric_value",
        "metric_value_num",
        "reference_value",
        "status",
        "country",
    ],
    "psi_buckets": [
        "run_id",
        "run_ts",
        "partition_date",
        "metric_type",
        "source_system",
        "table_name",
        "column_name",
        "metric_value",
        "status",
        "country",
    ],
    "csi_buckets": [
        "run_id",
        "run_ts",
        "partition_date",
        "metric_type",
        "source_system",
        "table_name",
        "column_name",
        "metric_value",
        "status",
        "country",
    ],
}

# Logical output name -> dataset name it is written to.
OUTPUT_DATASETS: Dict[str, str] = dict(
    metrics_history_df="DQ_METRIC_HISTORY",
    monitoring_report_df="DQ_MONITORING_REPORT",
    internal_errors="DQ_INTERNAL_ERRORS",
)


# -----------------------
# Thresholds
# -----------------------

dq_static_thresholds = {
    # ---------------- core DQ ----------------
    # % deviation vs reference allowed for rowcount.
    # NOTE(review): the key is suffixed _pct but the value is 0.01 — confirm
    # whether consumers read this as a fraction (1%) or a percentage (0.01%).
    "row_count_drift_pct": 0.01,
    # required non-null %
    "completeness_pct": 99.0,
    # required unique %
    "uniqueness_pct": 99.9,
    # min % rows that join successfully
    "join_consistency_pct": 95.0,
    "cross_system_key_consistency_pct": 90.0,
    # % values allowed outside declared ranges
    "max_out_of_range_pct": 0.0,
    # % date parse success required
    "date_parse_min_pct": 95.0,
    # freshness SLA
    "latency_days": 1,
    # cap on historical ref age
    "max_ref_partition_age_days": MAX_REF_AGE_DAYS,
    # lookback window for dynamic rowcount calc (days)
    "historical_partition_window": 30,

    # ---------------- drift: shared ----------------
    # rows sampled for drift calcs
    "drift_sample_size": DRIFT_SAMPLE_SIZE,
    # numeric bin count for PSI/Wasserstein buckets
    "drift_bins": DRIFT_BINS,
    # cap for categorical diagnostics
    "max_categorical_cardinality": MAX_CATEGORICAL_CARDINALITY,

    # ---------------- numeric drift ----------------
    # PSI diagnostic threshold (informational)
    "numeric_psi_max": 0.20,
    # relative median shift trigger (|Δmedian| / |median_ref|)
    "median_shift_pct": 0.10,
    # ws drift threshold
    "wasserstein_max": 1.0,

    # ---------------- categorical drift ----------------
    # Chi-square diagnostic threshold (informational)
    "categorical_csi_max": 50.0,
    # JS Shannon drift threshold (explicit).
    # NOTE(review): standard Jensen-Shannon divergence is bounded by ln 2
    # (~0.693) with natural log, or 1.0 with log base 2. If the check computes
    # the standard quantity, 1.15 can never be exceeded — confirm the scale.
    "js_divergence_max": 1.15,
    # abs change in top-K mass to trigger drift
    "topk_mass_delta": 0.05,
    # churn ratio in top-K categories to trigger drift
    "topk_churn": 0.30,
}


# System partition policy: source system -> default partition column.
# Table-level overrides win over these system defaults.
PARTITION_COL_BY_SYSTEM = dict(
    # yyyy-MM-dd
    TMX="edmp_insert_date",
    # yyyy-MM-dd
    ICM="ods",
    # yyyy-MM-dd
    IBANKING="ods",
    # yyyy-MM-dd (feature dates like 'valuedt' are NOT the partition key)
    EBBS="ods",
    # yyyyMMdd for vw_casa_acct; most others 'ods' (yyyy-MM-dd)
    HOGAN="process_date",
)



# Number of days of daily-level partitions retained by default.
DEFAULT_DAILY_RETENTION_DAYS = 60

# System-level retention overrides. Systems not listed here inherit the
# table/system setting or the default above.
DAILY_RETENTION_BY_SYSTEM = dict(
    ICM=45,  # month-end only beyond 45 days
)
# Retention can still be overridden per table with tbl_conf["daily_retention_days"],
# or per system with sys_conf["daily_retention_days"].


# -----------------------------
# Cross-system key harmonization (extracted)
# -----------------------------
# Each entry describes how a key in one system's table maps onto a key in
# another system's table, and which columns link the two sides.
cross_system_key_harmonization = [
    dict(
        description="Mapping TMX account_login to IBANKING tkn_relationshipno via cust_ebid",
        # "from" side
        source_system_from="TMX",
        table_from="vw_tmx_all_vrfy_usr_rspons",
        key_from="account_login",
        # "to" side
        source_system_to="IBANKING",
        table_to="actv_tbl_cust",
        key_to="tkn_relationshipno",
        # linking columns on each side
        common_linking_key_in_from="account_login",
        common_linking_key_in_to="cust_ebid",
        relationship_type="account_to_token",
    ),
]

# -----------------------
# Schemas
# -----------------------
# Schema of the per-metric history rows, declared as (name, type, nullable)
# triples in column order.
metrics_history_schema = StructType([
    T.StructField(field_name, field_type, is_nullable)
    for field_name, field_type, is_nullable in (
        ("run_id", T.StringType(), False),
        ("run_ts", T.StringType(), False),
        ("partition_date", T.StringType(), False),
        ("metric_type", T.StringType(), False),
        ("source_system", T.StringType(), True),
        ("table_name", T.StringType(), True),
        ("column_name", T.StringType(), True),
        ("metric_value", T.StringType(), True),
        ("metric_value_num", T.DoubleType(), True),
        ("threshold", T.StringType(), True),
        ("reference_value", T.DoubleType(), True),
        ("status", T.StringType(), False),
        ("alert_flag", T.StringType(), True),
        ("country", T.StringType(), True),
    )
])

# Schema of the consolidated monitoring report, declared as
# (name, type, nullable) triples in column order. Note that alert_flag sits
# directly after column_name here.
monitoring_report_schema = StructType([
    T.StructField(field_name, field_type, is_nullable)
    for field_name, field_type, is_nullable in (
        ("run_id", T.StringType(), False),
        ("run_ts", T.StringType(), False),
        ("partition_date", T.StringType(), False),
        ("metric_type", T.StringType(), False),
        ("source_system", T.StringType(), True),
        ("table_name", T.StringType(), True),
        ("column_name", T.StringType(), True),
        ("alert_flag", T.StringType(), True),
        ("metric_value", T.StringType(), True),
        ("metric_value_num", T.DoubleType(), True),
        ("threshold", T.StringType(), True),
        ("reference_value", T.DoubleType(), True),
        ("status", T.StringType(), False),
        ("country", T.StringType(), True),
    )
])

In [0]:
# -----------------------
# Table Configurations
# -----------------------

def _blank_column_meta(**overrides: Any) -> Dict[str, Any]:
    """Return a fresh column-metadata dict with every expected key present."""
    meta: Dict[str, Any] = {
        "role": None, "data_type": None, "cast_to": None, "preprocess": {},
        "sensitive": False, "tokenize": False, "tokenize_policy": None,
        "numerical_range": None, "thresholds": {}, "cardinality_hint": None,
        "skip_checks": [],
    }
    meta.update(overrides)
    return meta


def _column_meta_from_spec(spec: Dict[str, Any], fallback_range: Any) -> Dict[str, Any]:
    """Build column metadata from a user-supplied dict, keeping its explicit values."""
    return _blank_column_meta(
        role=spec.get("role"),
        data_type=spec.get("data_type"),
        cast_to=spec.get("cast_to"),
        preprocess=spec.get("preprocess", {}),
        sensitive=spec.get("sensitive", False),
        tokenize=spec.get("tokenize", False),
        tokenize_policy=spec.get("tokenize_policy"),
        numerical_range=spec.get("numerical_range") or fallback_range,
        thresholds=spec.get("thresholds", {}),
        cardinality_hint=spec.get("cardinality_hint"),
        skip_checks=spec.get("skip_checks", []),
    )


def _set_if_none(meta: Dict[str, Any], key: str, default: Any) -> None:
    """Assign ``default`` when ``key`` is absent or None.

    ``dict.setdefault`` cannot be used here: every metadata dict already
    carries all keys (as None placeholders), so setdefault would never fire.
    """
    if meta.get(key) is None:
        meta[key] = default


def _apply_numeric_defaults(meta: Dict[str, Any]) -> None:
    """Fill cast/preprocess defaults for a numeric feature without overriding explicit values."""
    _set_if_none(meta, "cast_to", "double")
    # An existing dict (even an empty one) is left untouched; only None is replaced.
    _set_if_none(meta, "preprocess", {"strip_commas": True, "trim": True})


def _infer_role_from_name(column_name: str) -> str:
    """Guess a column role from substrings of its name (first match wins)."""
    lname = column_name.lower()
    # NOTE(review): the substring test "id" is very broad (it also matches
    # e.g. "midrate" or "valid"); kept as-is to preserve classification.
    if "id" in lname or lname.endswith(("no", "num")):
        return "identifier"
    if any(tok in lname for tok in ("date", "dt", "timestamp", "ts")):
        return "timestamp"
    if any(tok in lname for tok in ("amt", "amount", "bal", "score", "limit", "sum", "rate")):
        return "numeric_feature"
    # Everything else, including code/type/flag/status-style names, is categorical.
    return "categorical_feature"


def validate_and_normalize_table_config(raw_conf: Dict[str, Dict[str, Any]], region: str) -> Dict[str, Dict[str, Any]]:
    """
    Normalize the raw table configuration:
      - convert legacy 'columns' lists into per-column metadata dicts
      - infer roles from hints (numeric_drift_cols, date_columns, join_key)
      - populate preprocess/cast_to defaults for numeric/date features
      - set per-table country/region defaults based on source_system and provided region

    Args:
        raw_conf: Mapping of system name -> system config; each system config
            may hold a "tables" mapping of table name -> table config.
            None or empty yields an empty result.
        region: Region code used as country/region default for systems without
            a fixed region; falls back to "hk" when empty or None.

    Returns:
        A deep-copied, normalized configuration with the same system/table
        keys. The input is never mutated. Each table's "columns" becomes a
        dict of column name -> metadata dict.
    """
    out: Dict[str, Dict[str, Any]] = {}
    reg_norm = (region or "hk").strip().lower()

    for sys_name, sys_conf in (raw_conf or {}).items():
        sys_copy = copy.deepcopy(sys_conf)  # do not mutate input
        sys_copy.setdefault("default_partition_col", "ods")

        normalized_tables: Dict[str, Dict[str, Any]] = {}

        for tbl_name, tbl_conf in (sys_copy.get("tables") or {}).items():
            tc = copy.deepcopy(tbl_conf)  # table copy

            if not tc.get("full_table_path"):
                schema = tc.get("schema")
                if schema:
                    tc["full_table_path"] = f"{schema}.{tbl_name}"
                else:
                    tc["full_table_path"] = tbl_name
                    logger.warning(f"{sys_name}.{tbl_name} missing full_table_path and schema; using table name '{tbl_name}'")

            # Table-level range declarations act as fallback for per-column ranges.
            ranges = tc.get("numerical_ranges") or {}

            # --- 1) explicit column declarations (list of str/dict, or dict) ---
            cols = tc.get("columns") or []
            col_map: Dict[str, Dict[str, Any]] = {}

            if isinstance(cols, list):
                for c in cols:
                    if isinstance(c, str):
                        col_map[c] = _blank_column_meta()
                    elif isinstance(c, dict) and c.get("name"):
                        col_map[c["name"]] = _column_meta_from_spec(c, ranges.get(c["name"]))
                    else:
                        logger.warning(f"Unrecognized column entry for {sys_name}.{tbl_name}: {c}")
            elif isinstance(cols, dict):
                for name, spec in cols.items():
                    col_map[name] = _column_meta_from_spec(spec or {}, ranges.get(name))

            # --- 2) hints: numeric drift columns ---
            for nm in (tc.get("numeric_drift_cols") or []):
                if nm not in col_map:
                    col_map[nm] = _blank_column_meta(
                        role="numeric_feature",
                        cast_to="double",
                        preprocess={"strip_commas": True, "trim": True},
                        numerical_range=ranges.get(nm),
                    )
                else:
                    meta = col_map[nm]
                    meta["role"] = meta.get("role") or "numeric_feature"
                    _apply_numeric_defaults(meta)

            # --- 3) hints: date columns (the format string is not used here) ---
            for nm in (tc.get("date_columns") or {}):
                if nm not in col_map:
                    col_map[nm] = _blank_column_meta(
                        role="timestamp", data_type="timestamp", cast_to="timestamp",
                    )
                else:
                    meta = col_map[nm]
                    meta["role"] = meta.get("role") or "timestamp"
                    _set_if_none(meta, "cast_to", "timestamp")

            # --- 4) hints: join key ---
            jk = tc.get("join_key")
            if jk:
                if jk not in col_map:
                    col_map[jk] = _blank_column_meta(role="primary_key")
                else:
                    col_map[jk]["role"] = col_map[jk].get("role") or "primary_key"

            # --- 5) tokenization candidates (only for already-known columns) ---
            for candidate in (tc.get("tokenize_candidates") or []):
                if candidate in col_map:
                    meta = col_map[candidate]
                    meta["tokenize"] = True
                    meta["sensitive"] = True
                    _set_if_none(meta, "tokenize_policy", PROTEGRITY_POLICY)

            # --- 6) final pass: infer missing roles, normalize containers ---
            for coln, meta in col_map.items():
                if not meta.get("role"):
                    meta["role"] = _infer_role_from_name(coln)
                    if meta["role"] == "numeric_feature":
                        _apply_numeric_defaults(meta)

                meta["skip_checks"] = meta.get("skip_checks") or []
                meta["thresholds"] = meta.get("thresholds") or {}

                if not meta.get("numerical_range"):
                    meta["numerical_range"] = ranges.get(coln)

            tc["columns"] = col_map

            tc.setdefault("feature_thresholds", {})
            tc.setdefault("tokenize_candidates", [])

            # --- 7) country/region defaults by source system ---
            src = (tc.get("source_system") or sys_name).strip().upper()
            if src == "HOGAN":
                geo, hint = "hk", None
            elif src == "EBBS":
                geo, hint = "sg", None
            elif src in ("ICM", "TMX", "IBANKING"):
                geo, hint = reg_norm, "hk/sg"
            else:
                geo, hint = reg_norm, "sg/hk"

            tc.setdefault("country", geo)
            tc.setdefault("region", geo)
            if hint is not None:
                tc.setdefault("region_hint", hint)

            normalized_tables[tbl_name] = tc

        sys_copy["tables"] = normalized_tables
        out[sys_name] = sys_copy

    return out

In [0]:
def get_table_config(region: Optional[str]) -> Dict[str, Dict[str, Any]]:
    """
    Build raw table configuration and then normalize it with validate_and_normalize_table_config.
    """
    region_code = (region or "hk").strip().lower()

    hog_schema = f"prd_vw_sri_hog_{region_code}_sen"
    hog_adm_schema = f"prd_vw_adm_{region_code}_nsen"
    ebbs_schema = f"prd_vw_sri_ebbs_{region_code}_sen"
    icm_schema = f"prd_vw_sri_icm_{region_code}_sen"
    ibanking_schema = f"prd_vw_sri_ibanking_{region_code}_tkn"
    tmx_schema = f"prd_vw_fdl_tmx_{region_code}_sen" if region_code == "sg" else f"prd_vw_sri_tmx_{region_code}_sen"

    raw = {
    "HOGAN": {
        "default_partition_col": "ods",
        "tables": {
            "actv_cst_data": {
                "full_table_path": f"{hog_schema}.actv_cst_data",
                "partition_col": "ods",
                "join_key": "cust_key",
                "columns": [
                    {"name": "cust_key", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "cust_id_type", "role": "categorical_feature"},
                    # Treat ID as sensitive + exclude from CSI drift
                    {"name": "cust_id_no", "role": "identifier", "tokenize": True, "tokenize_policy": PROTEGRITY_POLICY, "skip_checks": ["csi"]},
                    {"name": "cust_csi", "role": "categorical_feature"},
                ],
                "date_columns": {},
            },
            "actv_amt_tda_acct": {
                "full_table_path": f"{hog_schema}.actv_amt_tda_acct",
                "partition_col": "ods",
                "join_key": "acct_num",
                "timestamp_col": "act_opend_dt",
                "columns": [
                    {"name": "prmry_cust_key", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "acct_num", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "act_opend_dt", "role": "timestamp", "cast_to": "timestamp"},
                    {"name": "act_clsed_dt", "role": "timestamp", "cast_to": "timestamp"},
                    {"name": "act_st", "role": "categorical_feature"},
                    {"name": "act_prod", "role": "categorical_feature"},
                    {"name": "act_to_cust_rel", "role": "categorical_feature"},
                ],
                "date_columns": {"act_opend_dt": "yyyy-MM-dd", "act_clsed_dt": "yyyy-MM-dd"},
            },
            "actv_gaml_txn_cb": {
                "full_table_path": f"{hog_schema}.actv_gaml_txn_cb",
                "partition_col": "ods",
                "timestamp_col": "d_eff_dt",
                "columns": [
                    {"name": "d_acct_num", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "d_tran_ccy", "role": "categorical_feature"},
                    {"name": "d_eff_dt", "role": "timestamp", "cast_to": "timestamp"},
                    {"name": "d_tran_type", "role": "categorical_feature"},
                    {"name": "d_tran_amt", "role": "numeric_feature", "data_type": "decimal", "cast_to": "double",
                     "preprocess": {"strip_commas": True, "trim": True}, "thresholds": {"wasserstein": 0.05}},
                    {"name": "d_tran_code", "role": "categorical_feature"},
                    {"name": "d_tran_desc", "role": "categorical_feature"},
                    {"name": "d_matched_rev", "role": "categorical_feature"},
                ],
                "date_columns": {"d_eff_dt": "yyyy-MM-dd"},
                "numeric_drift_cols": ["d_tran_amt"],
            },
            "actv_tran_code": {
                "full_table_path": f"{hog_schema}.actv_tran_code",
                "partition_col": "ods",
                "join_key": "txn_code",
                "columns": [
                    {"name": "txn_code", "role": "categorical_feature"},
                    {"name": "desc1", "role": "categorical_feature"},
                ],
                "date_columns": {},
            },
            "vw_casa_acct": {
                "full_table_path": f"{hog_adm_schema}.vw_casa_acct",
                "partition_col": "process_date",
                "join_key": "n_acct",
                "timestamp_col": "process_date",
                "columns": [
                    {"name": "n_acct", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "c_acctccy", "role": "categorical_feature"},
                    {"name": "a_curbal_acctccy", "role": "numeric_feature", "cast_to": "double", "preprocess": {"strip_commas": True}},
                    {"name": "process_date", "role": "timestamp", "cast_to": "timestamp"},
                ],
                "date_columns": {"process_date": "yyyyMMdd"},
            },
            "actv_grouprates": {
                "full_table_path": "prd_vw_sri_act_grp_tkn.actv_grouprates",
                "partition_col": "ods",
                "join_key": "sc_cury",
                "columns": [
                    {"name": "sc_cury", "role": "categorical_feature"},
                    {"name": "sc_cury_2", "role": "categorical_feature"},
                    {"name": "sc_cash", "role": "numeric_feature", "cast_to": "double", "preprocess": {"strip_commas": True}},
                    {"name": "ods", "role": "timestamp"},
                ],
                "numeric_drift_cols": ["sc_cash"],
            },
        },
        "tables_to_join_for_consistency": [
            {"left_table": "actv_amt_tda_acct", "right_table": "actv_cst_data", "join_keys": ["prmry_cust_key"]},
            {"left_table": "actv_gaml_txn_cb", "right_table": "actv_tran_code", "join_keys": ["d_tran_code"]},
        ],
    },

    "EBBS": {
        "default_partition_col": "ods",
        "tables": {
            "actv_mastrel": {
                "full_table_path": f"{ebbs_schema}.actv_mastrel",
                "partition_col": "ods",
                "join_key": "masterno",
                "columns": [
                    {"name": "masterno", "role": "identifier", "tokenize": True, "tokenize_policy": PROTEGRITY_POLICY, "skip_checks": ["csi"]},
                    {"name": "relationshipno", "role": "identifier", "tokenize": True, "tokenize_policy": PROTEGRITY_POLICY, "skip_checks": ["csi"]},
                    {"name": "custsegmtcode", "role": "categorical_feature"},
                    {"name": "primaryflag", "role": "categorical_feature"},
                ],
                "date_columns": {},
            },
            "actv_rel": {
                "full_table_path": f"{ebbs_schema}.actv_rel",
                "partition_col": "ods",
                "join_key": "relationshipno",
                "columns": [
                    {"name": "relationshipno", "role": "identifier", "tokenize": True, "tokenize_policy": PROTEGRITY_POLICY, "skip_checks": ["csi"]},
                    {"name": "custsegmtcode", "role": "categorical_feature"},
                ],
                "date_columns": {},
            },
            "actv_account": {
                "full_table_path": f"{ebbs_schema}.actv_account",
                "partition_col": "ods",
                "join_key": "accountno",
                "columns": [
                    {"name": "masterno", "role": "identifier", "tokenize": True, "tokenize_policy": PROTEGRITY_POLICY, "skip_checks": ["csi"]},
                    {"name": "accountno", "role": "identifier", "tokenize": True, "tokenize_policy": PROTEGRITY_POLICY, "skip_checks": ["csi"]},
                    {"name": "acopendate", "role": "timestamp", "cast_to": "timestamp"},
                    {"name": "acctcurrentstatus", "role": "categorical_feature"},
                    {"name": "productcode", "role": "categorical_feature"},
                    {"name": "acclosedt", "role": "timestamp", "cast_to": "timestamp"},
                ],
                "date_columns": {"acopendate": "yyyy-MM-dd", "acclosedt": "yyyy-MM-dd"},
            },
            "actv_currrt": {
                "full_table_path": f"{ebbs_schema}.actv_currrt",
                "partition_col": "ods",
                "join_key": "currencycode",
                "columns": [
                    {"name": "currencycode", "role": "categorical_feature"},
                    {"name": "midrate", "role": "numeric_feature", "cast_to": "double"},
                ],
                "numeric_drift_cols": ["midrate"],
                "date_columns": {},
            },
            "actv_bebal": {
                "full_table_path": f"{ebbs_schema}.actv_bebal",
                "partition_col": "ods",
                "join_key": "accountno",
                "columns": [
                    {"name": "accountno", "role": "identifier", "tokenize": True, "tokenize_policy": PROTEGRITY_POLICY, "skip_checks": ["csi"]},
                    {"name": "currencycode", "role": "categorical_feature"},
                    {"name": "openingbalance", "role": "numeric_feature", "cast_to": "double", "preprocess": {"strip_commas": True}},
                    {"name": "date1", "role": "timestamp", "cast_to": "timestamp"},
                ],
                "date_columns": {"date1": "yyyy-MM-dd"},
                "numeric_drift_cols": ["openingbalance"],
            },
            "actv_trnarc": {
                "full_table_path": f"{ebbs_schema}.actv_trnarc",
                "partition_col": "ods",
                "join_key": "trncode",
                "columns": [
                    {"name": "accountno", "role": "identifier", "tokenize": True, "tokenize_policy": PROTEGRITY_POLICY, "skip_checks": ["csi"]},
                    {"name": "currencycode", "role": "categorical_feature"},
                    {"name": "valuedt", "role": "timestamp", "cast_to": "timestamp"},
                    {"name": "creditdebit", "role": "categorical_feature"},
                    {"name": "channelid", "role": "categorical_feature"},
                    {"name": "transactionamount", "role": "numeric_feature", "cast_to": "double", "preprocess": {"strip_commas": True}},
                    {"name": "trncode", "role": "categorical_feature"},
                    {"name": "transtypecode", "role": "categorical_feature"},
                    {"name": "NARRATION1", "role": "categorical_feature"},
                    {"name": "reversalflag", "role": "categorical_feature"},
                    {"name": "financialtrnflag", "role": "categorical_feature"},
                ],
                "date_columns": {"valuedt": "yyyy-MM-dd"},
                "numeric_drift_cols": ["transactionamount"],
            },
            "actv_trncd": {
                "full_table_path": f"{ebbs_schema}.actv_trncd",
                "partition_col": "ods",
                "join_key": "trncode",
                "columns": [
                    {"name": "trncode", "role": "categorical_feature"},
                    {"name": "description", "role": "categorical_feature"},
                ],
                "date_columns": {},
            },
            "actv_chnl": {
                "full_table_path": f"{ebbs_schema}.actv_chnl",
                "partition_col": "ods",
                "join_key": "channelid",
                "columns": [
                    {"name": "channelid", "role": "categorical_feature"},
                    {"name": "name", "role": "categorical_feature"},
                ],
                "date_columns": {},
            },
        },
        "tables_to_join_for_consistency": [
            {"left_table": "actv_mastrel", "right_table": "actv_account", "join_keys": ["masterno"]},
            {"left_table": "actv_trnarc", "right_table": "actv_trncd", "join_keys": ["trncode"]},
            {"left_table": "actv_trnarc", "right_table": "actv_chnl", "join_keys": ["channelid"]},
        ],
    },

    "ICM": {
        "default_partition_col": "ods",
        "tables": {
            "actv_cust_profiles": {
                "full_table_path": f"{icm_schema}.actv_cust_profiles",
                "partition_col": "ods",
                "join_key": "profile_id",
                "tokenize_candidates": ["relationship_number"],
                "columns": [
                    {"name": "profile_id", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "relationship_number", "role": "identifier", "tokenize": True, "tokenize_policy": PROTEGRITY_POLICY, "skip_checks": ["csi"]},
                    {"name": "date_of_birth", "role": "timestamp", "cast_to": "timestamp"},
                    {"name": "country_of_birth", "role": "categorical_feature"},
                    {"name": "qualification_code", "role": "categorical_feature"},
                    {"name": "resident_country", "role": "categorical_feature"},
                    {"name": "nationality_code", "role": "categorical_feature"},
                    {"name": "full_name", "role": "categorical_feature"},
                    # partitioned by ods (date); also track creation date as timestamp
                    {"name": "profile_create_date", "role": "timestamp", "cast_to": "timestamp"},
                ],
                "date_columns": {"profile_create_date": "yyyy-MM-dd", "date_of_birth": "yyyy-MM-dd"},
            },
            "actv_cust_risk_indicators": {
                "full_table_path": f"{icm_schema}.actv_cust_risk_indicators",
                "partition_col": "ods",
                "join_key": "profile_id",
                "timestamp_col": "risk_start_date",
                "columns": [
                    {"name": "profile_id", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "risk_code", "role": "categorical_feature"},
                    {"name": "risk_start_date", "role": "timestamp", "cast_to": "timestamp"},
                ],
                "date_columns": {"risk_start_date": "yyyy-MM-dd"},
            },
            "actv_cust_profile_business": {
                "full_table_path": f"{icm_schema}.actv_cust_profile_business",
                "partition_col": "ods",
                "join_key": "profile_id",
                "timestamp_col": "profile_create_date",
                "columns": [
                    {"name": "profile_id", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "rel_status", "role": "categorical_feature"},
                    {"name": "profile_create_date", "role": "timestamp", "cast_to": "timestamp"},
                ],
                "date_columns": {"profile_create_date": "yyyy-MM-dd"},
            },
            "actv_cust_product_references": {
                "full_table_path": f"{icm_schema}.actv_cust_product_references",
                "partition_col": "ods",
                "join_key": "profile_id",
                "timestamp_col": "product_open_date",
                "columns": [
                    {"name": "profile_id", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "product_reference_number", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "product_code", "role": "categorical_feature"},
                    {"name": "sub_product_code", "role": "categorical_feature"},
                    {"name": "product_reference_status", "role": "categorical_feature"},
                    {"name": "product_contact_type", "role": "categorical_feature"},
                    {"name": "product_open_date", "role": "timestamp", "cast_to": "timestamp"},
                ],
                "date_columns": {"product_open_date": "yyyy-MM-dd"},
            },
            "actv_cust_employments": {
                "full_table_path": f"{icm_schema}.actv_cust_employments",
                "partition_col": "ods",
                "join_key": "profile_id",
                "columns": [
                    {"name": "profile_id", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "profession_code", "role": "categorical_feature"},
                    {"name": "employer_name", "role": "categorical_feature"},
                    {"name": "own_organisation_name", "role": "categorical_feature"},
                    {"name": "emp_banking_indicator", "role": "categorical_feature"},
                ],
                "date_columns": {},
            },
            "actv_cust_contacts": {
                "full_table_path": f"{icm_schema}.actv_cust_contacts",
                "partition_col": "ods",
                "join_key": "profile_id",
                "columns": [
                    {"name": "profile_id", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "contact_classification_code", "role": "categorical_feature"},
                    {"name": "contact_country_code", "role": "categorical_feature"},
                    {"name": "contact_area_code", "role": "categorical_feature"},
                    # treat phone-like data as sensitive identifiers; exclude from CSI
                    {"name": "contact_number", "role": "identifier", "tokenize": True, "tokenize_policy": PROTEGRITY_POLICY, "skip_checks": ["csi"]},
                    {"name": "contact_extension", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "contact_reference", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "primary_contact", "role": "categorical_feature"},
                    {"name": "s_startdt", "role": "timestamp", "cast_to": "timestamp"},
                ],
                "date_columns": {"s_startdt": "yyyy-MM-dd"},
            },
        },
        "tables_to_join_for_consistency": [
            {"left_table": "actv_cust_profiles", "right_table": "actv_cust_profile_business", "join_keys": ["profile_id"]},
            {"left_table": "actv_cust_profiles", "right_table": "actv_cust_product_references", "join_keys": ["profile_id"]},
            {"left_table": "actv_cust_profiles", "right_table": "actv_cust_risk_indicators", "join_keys": ["profile_id"]},
            {"left_table": "actv_cust_profiles", "right_table": "actv_cust_employments", "join_keys": ["profile_id"]},
            {"left_table": "actv_cust_profiles", "right_table": "actv_cust_contacts", "join_keys": ["profile_id"]},
        ],
    },

    "TMX": {
        "default_partition_col": "edmp_insert_date",
        "tables": {
            "vw_tmx_all_vrfy_usr_rspons": {
                "full_table_path": f"{tmx_schema}.vw_tmx_all_vrfy_usr_rspons",
                "partition_col": "edmp_insert_date",
                "join_key": "request_id",
                "timestamp_col": "edmp_insert_timestamp",
                "columns": [
                    {"name": "request_id", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "edmp_insert_date", "role": "timestamp", "cast_to": "timestamp"},
                    {"name": "edmp_insert_timestamp", "role": "timestamp", "cast_to": "timestamp"},
                    {"name": "account_login", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "fuzzy_device_id", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "fuzzy_device_score", "role": "numeric_feature", "cast_to": "double"},
                    {"name": "agent_type", "role": "categorical_feature"},
                    {"name": "input_ip_address", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "event_type", "role": "categorical_feature"},
                ],
                "date_columns": {"edmp_insert_date": "yyyy-MM-dd", "edmp_insert_timestamp": "yyyy-MM-dd HH:mm:ss"},
            },
            "vw_tmx_all_vrfy_usr_rspons_ext1": {
                "full_table_path": f"{tmx_schema}.vw_tmx_all_vrfy_usr_rspons_ext1",
                "partition_col": "edmp_insert_date",
                "join_key": "request_id",
                "numerical_ranges": {
                    "summary_risk_score": {"min": 0, "max": 100},
                    "transaction_amount": {"min": 0, "max": 100_000_000}
                },
                "columns": [
                    {"name": "request_id", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "summary_risk_score", "role": "numeric_feature", "cast_to": "double"},
                    {"name": "transaction_amount_usd", "role": "numeric_feature", "cast_to": "double",
                     "preprocess": {"strip_commas": True}},
                ],
                "numeric_drift_cols": ["transaction_amount_usd"],
                "date_columns": {},
            },
            "vw_tmx_all_vrfy_usr_rspons_ext2": {
                "full_table_path": f"{tmx_schema}.vw_tmx_all_vrfy_usr_rspons_ext2",
                "partition_col": "edmp_insert_date",
                "join_key": "request_id",
                "timestamp_col": "event_datetime",
                "columns": [
                    {"name": "request_id", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "vpn_score", "role": "numeric_feature", "cast_to": "double"},
                    {"name": "transaction_amount_usd", "role": "numeric_feature", "cast_to": "double",
                     "preprocess": {"strip_commas": True}},
                    {"name": "event_datetime", "role": "timestamp", "cast_to": "timestamp"},
                ],
                "date_columns": {"event_datetime": "yyyy-MM-dd HH:mm:ss"},
                "numeric_drift_cols": ["transaction_amount_usd", "vpn_score"],
            },
            "vw_tmx_all_vrfy_usr_rspons_ext3": {
                "full_table_path": f"{tmx_schema}.vw_tmx_all_vrfy_usr_rspons_ext3",
                "partition_col": "edmp_insert_date",
                "join_key": "request_id",
                "columns": [
                    {"name": "request_id", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "account_customer_id", "role": "identifier", "skip_checks": ["csi"]},
                ],
                "date_columns": {},
            },
        },
        "tables_to_join_for_consistency": [
            {"left_table": "vw_tmx_all_vrfy_usr_rspons", "right_table": "vw_tmx_all_vrfy_usr_rspons_ext1", "join_keys": ["request_id"]},
            {"left_table": "vw_tmx_all_vrfy_usr_rspons", "right_table": "vw_tmx_all_vrfy_usr_rspons_ext2", "join_keys": ["request_id"]},
            {"left_table": "vw_tmx_all_vrfy_usr_rspons", "right_table": "vw_tmx_all_vrfy_usr_rspons_ext3", "join_keys": ["request_id"]},
        ],
    },

    "IBANKING": {
        "default_partition_col": "ods",
        "tables": {
            "actv_tbl_cust": {
                "full_table_path": f"{ibanking_schema}.actv_tbl_cust",
                "partition_col": "ods",
                "join_key": "cust_ebid",
                "columns": [
                    {"name": "cust_ebid", "role": "identifier", "skip_checks": ["csi"]},
                    {"name": "tkn_relationshipno", "role": "identifier", "tokenize": True,
                     "tokenize_policy": PROTEGRITY_POLICY, "skip_checks": ["csi"]},
                ],
                "date_columns": {},
            }
        },
        "tables_to_join_for_consistency": [],
    },
}


    return validate_and_normalize_table_config(raw_conf=raw, region=region)

In [0]:
# -----------------------
# Helpers
# -----------------------
def df_is_empty(df: Optional[DataFrame]) -> bool:
    """Return True when ``df`` is None, has no rows, or cannot be probed.

    Emptiness is decided by fetching at most one row (``df.head(1)``) instead
    of counting every row, so the check does not need a full count.

    Args:
        df: DataFrame to probe; ``None`` counts as empty.

    Returns:
        True if ``df`` is None, yields no row, or raises while being probed;
        False as soon as one row is available.
    """
    if df is None:
        return True
    try:
        # head(1) returns a list holding at most one Row.
        return len(df.head(1)) == 0
    except Exception as e:
        logger.error(f"Error checking if DataFrame is empty: {e}")
        # If we can't even fetch a row, assume it's effectively empty or problematic.
        return True

def write_dataset(dataset_name: str, df: Optional[DataFrame], mode: str = "append", **kwargs) -> None:
    """
    Persist a DataFrame into the named Dataiku dataset.

    - `mode` is forwarded as the write option; "append" or "overwrite" are expected.
    - Legacy callers may pass append=True/False; any non-None value replaces `mode`.
    - A None or empty DataFrame is skipped (logged at info level).
    - Write failures are logged with a traceback and not re-raised.
    """
    nothing_to_write = df is None or df_is_empty(df)
    if nothing_to_write:
        logger.info(f"Skipping write to dataset '{dataset_name}': DataFrame is None or empty.")
        return

    # Back-compat: an explicit append flag wins over the mode argument.
    legacy_append = kwargs.get("append")
    if legacy_append is not None:
        mode = "append" if legacy_append else "overwrite"

    try:
        target = dataiku.Dataset(dataset_name)
        # NOTE(review): assumes this dkuspark.write_with_schema accepts an `options`
        # keyword — confirm against the installed Dataiku API.
        dkuspark.write_with_schema(target, df, options={"mode": mode})
        logger.info(f"Wrote to '{dataset_name}' (mode={mode}).")
    except Exception:
        logger.exception(f"Failed to write Dataiku dataset '{dataset_name}'.")


def _lower_map(cols: List[str]) -> Dict[str, str]:
    """Creates a mapping from lowercase column name to original column name."""
    lm: Dict[str, str] = {}
    for c in cols:
        lc = c.lower()
        if lc in lm:
            raise ValueError(f"Duplicate case-insensitive column: '{lc}' -> {lm[lc]} and {c}")
        lm[lc] = c
    return lm


def resolve_col_name(df: DataFrame, name: Optional[str]) -> Optional[str]:
    """Return the DataFrame column that matches `name` ignoring case, or None.

    None is also returned when `df` is None or `name` is empty.
    """
    if df is None or not name:
        return None
    # The lowercase lookup is rebuilt from df.columns on every call; nothing is cached.
    lookup = _lower_map(df.columns)
    wanted = name.lower()
    return lookup[wanted] if wanted in lookup else None

def resolve_col_list(df: DataFrame, cols: List[str]) -> List[str]:
    """Map each requested name to the DataFrame's actual column, ignoring case.

    Names without a match are logged as warnings and left out of the result.
    A None DataFrame yields an empty list.
    """
    if df is None:
        return []
    lookup = _lower_map(df.columns)
    found: List[str] = []
    for requested in cols:
        actual = lookup.get(requested.lower())
        if not actual:
            logger.warning(f"Column '{requested}' not found in DataFrame schema.")
            continue
        found.append(actual)
    return found

def _system_partition_col(sys_name: str, tbl_conf: Dict[str, Any]) -> str:
    # Table override > table default_partition_col > system default > 'ods'
    return (
        tbl_conf.get("partition_col")
        or tbl_conf.get("default_partition_col")
        or PARTITION_COL_BY_SYSTEM.get((sys_name or "").upper(), "ods")
    )


def rows_to_df(spark_sess: SparkSession, rows: List[Dict[str, Any]], schema: Optional[StructType] = None) -> Optional[DataFrame]:
    """Build a Spark DataFrame from a list of dict rows.

    - No rows: an empty DataFrame when `schema` is given, otherwise None.
    - Creation failure: logged, then an empty DataFrame with `schema` is
      attempted (if given); None is returned when that is not possible.
    """
    has_schema = schema is not None
    if not rows:
        return spark_sess.createDataFrame([], schema=schema) if has_schema else None

    try:
        if has_schema:
            return spark_sess.createDataFrame(rows, schema=schema)
        return spark_sess.createDataFrame(rows)
    except Exception:
        logger.exception(f"Failed to create DataFrame from rows. Schema provided: {schema is not None}")
        if not has_schema:
            return None
        # A schema is known, so callers can still get a typed (empty) frame.
        try:
            return spark_sess.createDataFrame([], schema=schema)
        except Exception as schema_e:
            logger.exception(f"Failed even to create empty DataFrame with schema. Error: {schema_e}")
            return None

def _create_empty_df_with_schema(spark_sess: SparkSession, full_table_path: str, columns: Optional[List[str]] = None) -> DataFrame:
    """Return an empty DataFrame whose schema is read from `full_table_path`.

    When `columns` is given, only the table fields whose names match
    (case-insensitively) are kept. If no field matches, or the table schema
    cannot be read, the result has a single nullable string column 'info'.
    """
    def _placeholder_schema():
        # Single-column schema used whenever no real schema is available.
        return T.StructType([T.StructField("info", T.StringType(), True)])

    try:
        source_schema = spark_sess.read.table(full_table_path).schema
        target_schema = source_schema
        if columns:
            wanted = {name.lower() for name in columns}
            kept = [field for field in source_schema.fields if field.name.lower() in wanted]
            target_schema = T.StructType(kept) if kept else _placeholder_schema()
        return spark_sess.createDataFrame([], schema=target_schema)
    except Exception as e:
        logger.error(f"Could not create empty DataFrame with schema from '{full_table_path}': {e}. Returning DF with 'info' column.")
        return spark_sess.createDataFrame([], schema=_placeholder_schema())


def parse_to_timestamp_expr(col_name: str, fmt: Optional[str] = None):
    """Build a column expression that tries several ways to read `col_name` as a timestamp.

    Attempts are coalesced in this order: to_timestamp with `fmt` (only when
    `fmt` is given), to_timestamp without a format, then a plain cast to
    TimestampType.
    """
    source = F.col(col_name)
    attempts = [
        F.to_timestamp(source),
        source.cast(T.TimestampType()),
    ]
    if fmt:
        # The caller's preferred format is tried before the generic parsers.
        attempts.insert(0, F.to_timestamp(source, fmt))
    return F.coalesce(*attempts)


def _show_partitions_safe(full_table_path: str, partition_col: Optional[str] = None) -> List[str]:
    """
    List partition values (strings) for a table.
    1) SHOW PARTITIONS <table>  -> parse 'col=val' fragments
    2) Fallback: DISTINCT(partition_col) if provided (works for views)

    For multi-level specs ('a=1/b=2') the fragment whose key equals
    `partition_col` (case-insensitive) is used; without a match, or when
    `partition_col` is not given, the last fragment is used. Repeated values
    are returned once, in first-seen order.

    Returns [] when neither strategy yields anything.
    """
    wanted_key = partition_col.lower() if partition_col else None

    def _value_from_spec(spec: str) -> str:
        # Pull one value out of a 'k1=v1/k2=v2' partition spec.
        fragments = spec.split("/")
        if wanted_key:
            for fragment in fragments:
                key, sep, value = fragment.partition("=")
                if sep and key.lower() == wanted_key:
                    return value
        last = fragments[-1]
        return last.split("=", 1)[1] if "=" in last else last

    # Quote every dotted part on its own: one pair of backticks around
    # "schema.table" is read as a single identifier literally named "schema.table".
    quoted_path = ".".join("`" + part.strip("`") + "`" for part in full_table_path.split("."))

    try:
        found = {}  # dict keys keep first-seen order while dropping repeats
        for row in spark.sql(f"SHOW PARTITIONS {quoted_path}").collect():
            spec = row[0] if row and len(row) > 0 else None
            if spec:
                found[_value_from_spec(spec)] = None
        return list(found)
    except Exception as exc:
        # Expected for views / non-partitioned tables; record it instead of hiding it.
        logger.info("SHOW PARTITIONS unavailable for %s (%s); trying DISTINCT fallback.", full_table_path, exc)

    if partition_col:
        try:
            df = spark.read.table(full_table_path).select(F.col(partition_col).alias("p")).distinct()
            return [str(r["p"]) for r in df.collect() if r and r["p"] is not None]
        except Exception:
            logger.exception(f"Fallback DISTINCT failed for {full_table_path}.{partition_col}")

    return []

def _month_key(d: date) -> str:
    return f"{d.year:04d}-{d.month:02d}"

def _pick_month_end_partition(partitions: List[str], year: int, month: int) -> Optional[str]:
    """
    From a list of YYYY-MM-DD (or YYYYMMDD) strings, pick the max 'ods' within that month.
    """
    if not partitions:
        return None
    wanted_prefix1 = f"{year:04d}-{month:02d}-"   # ISO
    wanted_prefix2 = f"{year:04d}{month:02d}"     # compact
    # collect all dates in that month we can parse
    candidates: List[Tuple[str, date]] = []
    for p in partitions:
        d = _parse_ymd_safe(p)
        if not d:
            continue
        if d.year == year and d.month == month:
            candidates.append((p, d))
    if not candidates:
        return None
    # pick the max date (month-end or last available within that month)
    candidates.sort(key=lambda t: t[1])
    return candidates[-1][0]

def _resolve_daily_retention_days(sys_name: str, sys_conf: Dict[str, Any], tbl_conf: Dict[str, Any]) -> int:
    """
    Priority: table override > system override (conf) > system default map > global default
    """
    if "daily_retention_days" in tbl_conf:
        return int(tbl_conf["daily_retention_days"])
    if "daily_retention_days" in (sys_conf or {}):
        return int(sys_conf["daily_retention_days"])
    if sys_name in DAILY_RETENTION_BY_SYSTEM:
        return int(DAILY_RETENTION_BY_SYSTEM[sys_name])
    return int(DEFAULT_DAILY_RETENTION_DAYS)

def resolve_effective_partition_for_request(
    *,
    full_table_path: str,
    partition_col: str,
    requested_date_str: str,
    sys_name: str,
    sys_conf: Dict[str, Any],
    tbl_conf: Dict[str, Any],
    today: date
) -> Optional[str]:
    """
    Decide which partition value to load for a requested date.

    Inside the retention window (requested date >= today - retention days,
    with at least 1 day): the exact requested string if listed, else its
    YYYYMMDD form if listed, else the latest listed day between the window
    start and the requested date.

    Outside the window: the latest listed day of the requested month, else
    the latest listed day on or before the requested date.

    Returns the partition string exactly as listed, or None when the request
    cannot be parsed, no partitions are listed, or nothing qualifies.
    """
    target_day = _parse_ymd_safe(requested_date_str)
    if not target_day:
        logger.warning("resolve_effective_partition: requested_date_str not parseable: %s", requested_date_str)
        return None

    window_days = max(1, _resolve_daily_retention_days(sys_name, sys_conf, tbl_conf))
    window_start = today - timedelta(days=window_days)

    # Partition listing is fetched once and reused by every branch below.
    available = _show_partitions_safe(full_table_path, partition_col)
    if not available:
        logger.warning("No partitions found for %s.", full_table_path)
        return None

    def _latest_up_to_target(earliest):
        # Latest parseable partition <= target_day (and >= earliest when given);
        # on equal dates the entry listed first wins.
        best = None
        for part in available:
            part_day = _parse_ymd_safe(part)
            if not part_day or part_day > target_day:
                continue
            if earliest is not None and part_day < earliest:
                continue
            if best is None or part_day > best[1]:
                best = (part, part_day)
        return best[0] if best else None

    if target_day >= window_start:
        # Daily window: exact string first, then the compact spelling.
        if requested_date_str in available:
            return requested_date_str
        compact_form = target_day.strftime("%Y%m%d")
        if compact_form in available:
            return compact_form
        return _latest_up_to_target(window_start)

    # Older than the daily window: month-end snapshot, else anything earlier.
    month_end = _pick_month_end_partition(available, target_day.year, target_day.month)
    if month_end:
        return month_end
    return _latest_up_to_target(None)


def sample_df(df: DataFrame, n: int, seed: int = 42) -> DataFrame:
    """Return at most `n` rows of `df`.

    - Empty (per df_is_empty) input is returned unchanged.
    - With `n` rows or fewer, `df` itself is returned.
    - Otherwise a seeded sample without replacement is drawn with a fraction
      20% above n/total (capped at 1.0) and then limited to `n` rows.
    - If counting or sampling raises, `df.limit(n)` is returned instead.
    """
    if df_is_empty(df):
        return df
    try:
        row_count = df.count()
        if row_count > n:
            # Oversample slightly so the limit(n) below is likely to be filled.
            fraction = min(1.2 * (n / float(row_count)), 1.0)
            return df.sample(withReplacement=False, fraction=fraction, seed=seed).limit(n)
        logger.debug(f"DataFrame has {row_count} rows, which is <= sample size {n}. Returning all rows.")
        return df
    except Exception:
        logger.exception(f"Error during DataFrame sampling to get {n} rows.")
        return df.limit(n)

def _safe_float(val: Any) -> Optional[float]:
    if val is None:
        return None
    try:
        out = float(val)
        if math.isnan(out) or math.isinf(out):
            return None
        return out
    except (ValueError, TypeError):
        return None


def _normalize_to_string(val: Any) -> Optional[str]:
    if val is None:
        return None
    try:
        if isinstance(val, dict):
            return json.dumps(val, ensure_ascii=False, default=str)
        return str(val)
    except TypeError:
        logger.warning("normalize_to_string: non-serializable type %s", type(val).__name__)
        return None




def _normalize_metric_record(rec: Dict[str, Any]) -> Dict[str, Any]:
    """
    Shape one raw metric dict into a row keyed by the metrics_history_schema fields.

      - Keys that are not schema fields are dropped; absent schema fields start as None
      - metric_value_num / reference_value -> finite float or None
      - status -> PASS / FAIL / UNCOMPUTABLE / WARN / ERROR (unrecognized text -> UNCOMPUTABLE)
      - alert_flag -> "Fail" or None
      - Every other kept field -> text via _normalize_to_string
      - Falsy run_id, run_ts, partition_date, metric_type, status, country get defaults
    """
    # relies on a global metrics_history_schema (StructType)
    known_fields = {field.name for field in metrics_history_schema.fields}
    normalized: Dict[str, Any] = dict.fromkeys(known_fields)

    numeric_keys = ("metric_value_num", "reference_value")
    accepted_statuses = {"PASS", "FAIL", "UNCOMPUTABLE", "WARN", "ERROR"}
    pass_aliases = {"TRUE", "1", "PASSED"}
    fail_aliases = {"FALSE", "0", "FAILED"}
    alert_aliases = {"fail", "failed", "true", "1", "alert"}

    def _status_of(raw):
        # Canonical status text; anything unrecognized is UNCOMPUTABLE.
        try:
            token = str(raw).strip().upper()
        except Exception:
            return "UNCOMPUTABLE"
        if token in accepted_statuses:
            return token
        if token in pass_aliases:
            return "PASS"
        if token in fail_aliases:
            return "FAIL"
        return "UNCOMPUTABLE"

    def _alert_of(raw):
        # standardize to "Fail" or None
        if raw is None:
            return None
        return "Fail" if str(raw).strip().lower() in alert_aliases else None

    for field_name, raw in rec.items():
        if field_name not in known_fields:
            continue
        if field_name in numeric_keys:
            normalized[field_name] = _safe_float(raw)
        elif field_name == "status":
            normalized[field_name] = _status_of(raw)
        elif field_name == "alert_flag":
            normalized[field_name] = _alert_of(raw)
        else:
            # metric_value, threshold, details and all remaining fields are stored as text
            normalized[field_name] = _normalize_to_string(raw)

    # Fill defaults if missing
    fallbacks = (
        ("run_id", RUN_ID),
        ("run_ts", RUN_TS),
        ("partition_date", CURR_DATE_STR),
        ("metric_type", "unknown"),
        ("status", "PASS"),
        ("country", REGION),
    )
    for field_name, fallback in fallbacks:
        if not normalized.get(field_name):
            normalized[field_name] = fallback

    return normalized



def sanitize_all_metrics_return_new(all_metrics_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Normalize every metric record and return the results as a new list.

    A record that fails normalization does not stop the run: the failure is
    logged and a 'metric_sanitization_failure' row (status UNCOMPUTABLE,
    alert_flag Fail) is appended in its place.
    """
    cleaned: List[Dict[str, Any]] = []
    for position, record in enumerate(all_metrics_list):
        try:
            cleaned.append(_normalize_metric_record(record))
        except Exception as e:
            logger.exception(f"Failed to sanitize metric record at index {position}.")
            # Placeholder row so the failure is visible in the output.
            failure_row = {
                "run_id": RUN_ID,
                "run_ts": RUN_TS,
                "partition_date": CURR_DATE_STR,
                "metric_type": "metric_sanitization_failure",
                "metric_value": f"Error sanitizing record {_normalize_to_string(record)}: {e}",
                "status": "UNCOMPUTABLE",
                "alert_flag": "Fail",
                "country": REGION
            }
            cleaned.append(failure_row)
    return cleaned

def parse_buckets_from_metric_value(metric_value: Union[str, Dict[str, Any]]) -> Tuple[Optional[List[float]], Optional[List[int]]]:
    """Extract histogram edges and counts from a stored metric value.

    `metric_value` is either a dict or a JSON string decoding to a dict with
    list-valued "edges" and "counts". Edges are converted with _safe_float and
    counts with int().

    Returns:
        (edges, counts) on success; (None, None) when the payload is not a
        dict, a list is missing, an edge is not a finite number, a count
        cannot be converted, or the JSON is invalid. Each failure is logged.
    """
    failure = (None, None)
    try:
        payload = json.loads(metric_value) if isinstance(metric_value, str) else metric_value
        if not isinstance(payload, dict):
            logger.warning("parse_buckets: not a dict")
            return failure

        raw_edges = payload.get("edges")
        raw_counts = payload.get("counts")
        if not (isinstance(raw_edges, list) and isinstance(raw_counts, list)):
            logger.warning("parse_buckets: missing edges/counts")
            return failure

        edge_values = [_safe_float(item) for item in raw_edges]
        if any(value is None for value in edge_values):
            logger.warning("parse_buckets: bad edge value")
            return failure

        count_values = list(map(int, raw_counts))
        return edge_values, count_values
    except json.JSONDecodeError:
        logger.warning("parse_buckets: invalid JSON")
        return failure
    except (TypeError, ValueError):
        logger.warning("parse_buckets: bad element type")
        return failure
    except Exception:
        logger.exception("parse_buckets: unexpected")
        return failure

In [0]:
# Drift calculation helpers

def _counts_to_prob_vector(counts: List[float], eps: float = EPSILON) -> List[float]:
    """Turn bucket counts into per-bucket proportions.

    Each proportion is count/total, floored at `eps`; the floored values are
    not re-normalized. When the total is not positive, every bucket gets the
    same weight 1/len(counts).
    """
    grand_total = float(sum(counts))
    if grand_total <= 0:
        uniform = 1.0 / max(1, len(counts))
        return [uniform for _ in counts]
    floored: List[float] = []
    for count in counts:
        floored.append(max(float(count) / grand_total, eps))
    return floored


def compute_chi2_and_df_from_maps(cur_map: Dict[str, int], ref_map: Dict[str, int], eps: float = 1e-9) -> Optional[Tuple[float, int]]:
    """Compute a chi-square style statistic of current counts against reference counts.

    For every category present in either map, (observed - expected)^2 /
    (expected + eps) is accumulated, with observed taken from `cur_map` and
    expected from `ref_map` (missing categories count as 0). The counts are
    used as given; no rescaling between the two maps is done here.

    Returns:
        (statistic, degrees_of_freedom), where degrees_of_freedom is the
        number of categories with a positive expected count minus one (never
        below 0); None when either map is empty or an error occurs.
    """
    try:
        if not (cur_map and ref_map):
            logger.warning("Chi2: empty inputs")
            return None
        categories = set(cur_map) | set(ref_map)
        if not categories:
            # Defensive: nothing to compare.
            return 0.0, 0

        statistic = 0.0
        expected_positive = 0
        for category in categories:
            observed = float(cur_map.get(category, 0.0))
            expected = float(ref_map.get(category, 0.0))
            if expected > 0:
                expected_positive += 1
            # eps keeps the division defined when the reference count is 0
            statistic += ((observed - expected) ** 2) / (expected + eps)
        return float(statistic), max(0, expected_positive - 1)
    except Exception:
        logger.exception("Chi2: failure")
        return None


def _wasserstein_from_buckets(edges: List[float], cur_counts: List[int], ref_counts: List[int]) -> Optional[float]:
    try:
        if not edges or not cur_counts or not ref_counts:
            return None
        n_bins = len(edges) - 1
        if len(cur_counts) != n_bins or len(ref_counts) != n_bins:
            return None
        cur_prob = _counts_to_prob_vector(cur_counts)
        ref_prob = _counts_to_prob_vector(ref_counts)
        cdf_cur = [sum(cur_prob[:i+1]) for i in range(n_bins)]
        cdf_ref = [sum(ref_prob[:i+1]) for i in range(n_bins)]
        w = 0.0
        for i in range(n_bins):
            width = float(edges[i+1]) - float(edges[i])
            diff = abs(cdf_cur[i] - cdf_ref[i])
            w += diff * width
        return float(w)
    except Exception:
        return None

def _jensen_shannon_from_maps(cur_map: Dict[str, int], ref_map: Dict[str, int], top_k: Optional[int] = None, eps: float = EPSILON) -> Optional[float]:
    """Jensen-Shannon divergence (natural log) between two category-count maps.

    When `top_k` is set and there are more distinct categories than that, only
    the `top_k` categories with the largest combined count are kept and the
    rest of each map is folded into an "__OTHER__" bucket. Counts are turned
    into proportions with _counts_to_prob_vector using `eps`.

    Returns:
        The divergence as a float, or None if any error occurs.
    """
    try:
        cur_map = cur_map or {}
        ref_map = ref_map or {}
        all_keys = set(cur_map.keys()) | set(ref_map.keys())

        if top_k and len(all_keys) > top_k:
            combined = {k: cur_map.get(k, 0) + ref_map.get(k, 0) for k in all_keys}
            ranked = sorted(combined, key=combined.get, reverse=True)
            kept = set(ranked[:top_k])

            def _fold_tail(source):
                # Keep the leading categories; lump everything else together.
                folded = {}
                tail = 0
                for category, count in source.items():
                    if category in kept:
                        folded[category] = count
                    else:
                        tail += count
                if tail > 0:
                    folded["__OTHER__"] = tail
                return folded

            cur_map = _fold_tail(cur_map)
            ref_map = _fold_tail(ref_map)
            all_keys = set(cur_map.keys()) | set(ref_map.keys())

        ordered = sorted(all_keys)
        p = _counts_to_prob_vector([float(cur_map.get(k, 0)) for k in ordered], eps)
        q = _counts_to_prob_vector([float(ref_map.get(k, 0)) for k in ordered], eps)
        midpoint = [(p_i + q_i) / 2.0 for p_i, q_i in zip(p, q)]

        def _kl(first, second):
            # Kullback-Leibler sum; terms with a non-positive left weight are skipped.
            acc = 0.0
            for a_i, b_i in zip(first, second):
                if a_i <= 0:
                    continue
                acc += a_i * math.log(a_i / b_i)
            return acc

        divergence = 0.5 * (_kl(p, midpoint) + _kl(q, midpoint))
        return float(divergence)
    except Exception:
        return None

def _entropy_from_map(m: Dict[str, int], eps: float = EPSILON) -> float:
    """Shannon entropy (natural log) of a category-count map.

    Proportions are floored at `eps` before the logarithm. An empty/None map,
    or one whose counts do not sum to a positive number, gives 0.0.
    """
    weights = list(map(float, (m or {}).values()))
    mass = sum(weights)
    if mass <= 0:
        return 0.0
    shares = [max(weight / mass, eps) for weight in weights]
    weighted_logs = (share * math.log(share) for share in shares)
    return float(-sum(weighted_logs))

def _topk_churn_and_mass_change(cur_map: Dict[str, int], ref_map: Dict[str, int], k: int = 10) -> Tuple[float, int, int, float]:
    cur_map = cur_map or {}
    ref_map = ref_map or {}
    cur_sorted = sorted(cur_map.items(), key=lambda kv: kv[1], reverse=True)
    ref_sorted = sorted(ref_map.items(), key=lambda kv: kv[1], reverse=True)
    topk_cur = [x for x, _ in cur_sorted[:k]]
    topk_ref = [x for x, _ in ref_sorted[:k]]
    removed = len([x for x in topk_ref if x not in topk_cur])
    added = len([x for x in topk_cur if x not in topk_ref])
    def topk_share(m, keys):
        total = float(sum(m.values()) or 0.0)
        if total == 0.0:
            return 0.0
        return sum(m.get(x, 0) for x in keys) / total
    cur_share = topk_share(cur_map, topk_cur)
    ref_share = topk_share(ref_map, topk_ref)
    mass_delta = cur_share - ref_share
    churn = (removed + added) / float(max(1, k))
    return float(mass_delta), removed, added, float(churn)

In [0]:
def _swap_region_token(path: str, from_region: str, to_region: str) -> str:
    import re
    newp = path
    patterns = [
        (rf"(?i)([_\-\./]){from_region}([_\-\./])", rf"\1{to_region}\2"),  # _sg_ .sg. -sg-
        (rf"(?i)([_\-\./]){from_region}$", rf"\1{to_region}"),             # _sg end
        (rf"(?i)^{from_region}([_\-\./])", rf"{to_region}\1"),             # sg_ start
    ]
    for pat, repl in patterns:
        newp = re.sub(pat, repl, newp)
    return newp

def _parse_ymd_safe(s: str) -> Optional[date]:
    if not s:
        return None
    for fmt in ("%Y-%m-%d", "%Y%m%d"):
        try:
            return datetime.strptime(s, fmt).date()
        except (ValueError, TypeError):
            pass
    return None


def choose_recent_ref_partition(
    partitions: List[str],
    anchor_date: date,
    max_age_days: int,
    *,
    require_iso: bool = False
) -> Optional[str]:
    """Pick the reference partition closest to (and not after) ``anchor_date``.

    Partition names are parsed with ``_parse_ymd_safe``. Among those dated on
    or before ``anchor_date``, the newest one within ``max(1, max_age_days)``
    days is preferred; if none is that recent, the newest dated partition is
    returned anyway.

    When no name parses to a usable date:
      - ``require_iso=True``  -> None
      - ``require_iso=False`` -> the lexicographically largest partition name.

    Returns None (with a warning) for an empty partition list.
    """
    if not partitions:
        logger.warning("ref_pick: no partitions")
        return None

    oldest_allowed = anchor_date - timedelta(days=max(1, max_age_days))

    dated: List[Tuple[str, date]] = []
    for name in partitions:
        when = _parse_ymd_safe(name)
        if not when:
            if require_iso:
                logger.warning("ref_pick: reject non-ISO %s", name)
            continue
        if when <= anchor_date:
            dated.append((name, when))

    if dated:
        fresh = [item for item in dated if item[1] >= oldest_allowed]
        # max() returns the first of equally-dated entries, like a stable
        # descending sort followed by taking element 0.
        best_name, _ = max(fresh or dated, key=lambda item: item[1])
        return best_name

    if require_iso:
        return None
    return sorted(partitions)[-1]


# ---- partition value formatting (matches storage format/type) -------------------------
def _format_partition_literal(
    partition_value: str,
    pfield_dtype: T.DataType,
    pcol: str,
    date_columns: Optional[Dict[str, str]]
) -> F.Column:
    """Build a Spark literal for ``partition_value`` matching the partition column's type.

    ``date_columns`` optionally maps a column name to its date format string;
    the entry for ``pcol`` (if any) drives parsing/formatting:
      - DATE column: ``to_date`` with the declared format (default parsing
        when no format is declared or it is ``yyyy-MM-dd``).
      - TIMESTAMP column: ``to_timestamp`` likewise (default parsing when the
        format starts with ``yyyy-MM-dd``).
      - Integer/long/short/float/double/decimal/boolean column: plain cast.
      - STRING column with a declared format: the value is read as
        ``yyyy-MM-dd`` and re-rendered in that format (e.g. ``yyyyMMdd``).
      - Anything else: the raw string literal.
    """
    raw = F.lit(partition_value)
    declared_fmt = (date_columns or {}).get(pcol)

    if isinstance(pfield_dtype, T.DateType):
        if not declared_fmt or declared_fmt == "yyyy-MM-dd":
            return F.to_date(raw)
        return F.to_date(raw, declared_fmt)

    if isinstance(pfield_dtype, T.TimestampType):
        if not declared_fmt or declared_fmt.startswith("yyyy-MM-dd"):
            return F.to_timestamp(raw)
        return F.to_timestamp(raw, declared_fmt)

    castable_types = (
        T.IntegerType, T.LongType, T.ShortType, T.FloatType,
        T.DoubleType, T.DecimalType, T.BooleanType,
    )
    if isinstance(pfield_dtype, castable_types):
        return raw.cast(pfield_dtype)

    if declared_fmt and isinstance(pfield_dtype, T.StringType):
        return F.date_format(F.to_date(raw, "yyyy-MM-dd"), declared_fmt)

    return raw

In [0]:
def compute_psi_from_counts(cur_counts: List[int], ref_counts: List[int], eps: float = 1e-9) -> Optional[float]:
    """Population Stability Index from two aligned bucket-count lists.

    Each count is floored at ``eps`` before being normalised to a share, so
    empty buckets do not produce ``log(0)`` or a division by zero.

    Returns:
        The PSI as a float, or None (with a log entry) when the inputs are
        empty, differ in length, sum to a non-positive value, or an
        unexpected error occurs.
    """
    try:
        if not (cur_counts and ref_counts and len(cur_counts) == len(ref_counts)):
            logger.warning("PSI: invalid inputs")
            return None

        cur_floored = [max(eps, float(value)) for value in cur_counts]
        ref_floored = [max(eps, float(value)) for value in ref_counts]
        cur_total = sum(cur_floored)
        ref_total = sum(ref_floored)
        if cur_total <= 0 or ref_total <= 0:
            logger.warning("PSI: zero sums")
            return None

        total = 0.0
        for cur_value, ref_value in zip(cur_floored, ref_floored):
            cur_share = cur_value / cur_total
            ref_share = ref_value / ref_total
            total += (cur_share - ref_share) * math.log(cur_share / ref_share)
        return float(total)
    except Exception:
        logger.exception("PSI: failure")
        return None

def wasserstein_from_buckets(edges: List[float], cur_counts: List[int], ref_counts: List[int], eps: float = 1e-12) -> Optional[float]:
    """Approximate 1-D Wasserstein distance from shared histogram buckets.

    ``edges`` holds ``n + 1`` strictly increasing bucket boundaries and both
    count lists hold ``n`` entries. Negative counts are clamped to zero, each
    side is normalised to a distribution, and the absolute difference of the
    two running CDFs is summed, weighted by bucket width (floored at ``eps``).

    Returns:
        The distance as a float, or None (with a log entry) for missing
        inputs, mismatched shapes, non-increasing edges, non-positive totals,
        or an unexpected error.
    """
    try:
        if not (edges and cur_counts and ref_counts):
            logger.warning("W1D: missing inputs")
            return None

        n_buckets = len(edges) - 1
        if n_buckets <= 0 or len(cur_counts) != n_buckets or len(ref_counts) != n_buckets:
            logger.warning("W1D: shape mismatch")
            return None

        for idx, (lower, upper) in enumerate(zip(edges, edges[1:])):
            if not (upper > lower):
                logger.warning("W1D: non-monotonic edges at %d", idx)
                return None

        cur_mass = [max(0.0, float(value)) for value in cur_counts]
        ref_mass = [max(0.0, float(value)) for value in ref_counts]
        cur_total, ref_total = sum(cur_mass), sum(ref_mass)
        if cur_total <= 0 or ref_total <= 0:
            logger.warning("W1D: zero sums")
            return None

        cur_share = [value / cur_total for value in cur_mass]
        ref_share = [value / ref_total for value in ref_mass]

        # Walk the buckets once, carrying both running CDFs along.
        cur_cdf = 0.0
        ref_cdf = 0.0
        distance = 0.0
        for idx in range(n_buckets):
            cur_cdf += cur_share[idx]
            ref_cdf += ref_share[idx]
            width = max(eps, (edges[idx + 1] - edges[idx]))
            distance += abs(cur_cdf - ref_cdf) * width
        return float(distance)
    except Exception:
        logger.exception("W1D: failure")
        return None



# def _compute_chi2_from_maps(cur_map: Dict[str, int], ref_map: Dict[str, int], eps: float = EPSILON) -> Optional[float]:
#     """
#     Computes Chi-Squared statistic from current and reference category counts.
#     This is related to CSI.
#     """
#     if cur_map is None and ref_map is None:
#         return None
#     cur_map = cur_map or {}
#     ref_map = ref_map or {}

#     # Get all unique categories present in either map
#     all_cats = sorted(set(cur_map.keys()) | set(ref_map.keys()))

#     cur_counts = [int(cur_map.get(c, 0)) for c in all_cats]
#     ref_counts = [int(ref_map.get(c, 0)) for c in all_cats]

#     total_cur = float(sum(cur_counts))
#     total_ref = float(sum(ref_counts))
#     total_all = total_cur + total_ref

#     if total_all == 0:
#         return 0.0 # No data, Chi2 is 0.

#     chi2 = 0.0
#     for i in range(len(all_cats)):
#         observed_cur = float(cur_counts[i])
#         observed_ref = float(ref_counts[i])

#         # Expected counts calculation for independence test
#         # Expected_cur = total_cur * (proportion of this category in ref_map or overall)
#         # A more standard way is to use marginal totals:
#         # Expected_ij = (row_total * column_total) / grand_total
#         # Here, row_total = total_cur if current, total_ref if reference
#         # column_total = sum of counts for this category across both
#         # grand_total = total_all

#         # Simplified expectation calculation assuming marginals:
#         # E_cur = (total_cur * (observed_cur + observed_ref)) / total_all
#         # E_ref = (total_ref * (observed_cur + observed_ref)) / total_all

#         # If the dataset has categories only in current or reference, the total for that category across both may be non-zero.
#         category_total = observed_cur + observed_ref
#         if category_total == 0:
#             continue # No counts for this category, skip

#         # Expected count for current partition, given marginals
#         expected_cur = (total_cur * category_total) / total_all if total_all > 0 else 0.0
#         # Expected count for reference partition, given marginals
#         expected_ref = (total_ref * category_total) / total_all if total_all > 0 else 0.0

#         # Add epsilon to expected values to prevent division by zero
#         if expected_cur < eps: expected_cur = eps
#         if expected_ref < eps: expected_ref = eps

#         try:
#             # (Observed - Expected)^2 / Expected
#             if observed_cur > 0 or expected_cur > eps: # Only add if there's an observation or a non-negligible expectation
#                 chi2 += ((observed_cur - expected_cur) ** 2) / expected_cur
#             if observed_ref > 0 or expected_ref > eps: # Only add if there's an observation or a non-negligible expectation
#                 chi2 += ((observed_ref - expected_ref) ** 2) / expected_ref
#         except ZeroDivisionError:
#             logger.warning(f"Chi2 calculation encountered division by zero for category '{all_cats[i]}'.")
#             continue

#     return float(chi2)

def compute_numeric_psi_and_buckets(
    col: str,
    cur_df: DataFrame,
    ref_df: Optional[DataFrame],
    bins: int = 10,
    sample_size: int = 50000,
) -> Tuple[Optional[float], Optional[List[float]], Optional[List[int]], Optional[List[int]]]:
    """Bucket a numeric column on quantile edges and compute PSI against a reference.

    Args:
        col: Column to analyse. Commas are stripped and the value is cast to
            double; rows that are null after the cast are dropped.
        cur_df: Current partition.
        ref_df: Reference partition, or None to only produce current buckets.
        bins: Number of quantile buckets requested.
        sample_size: When positive, each side is capped with ``limit`` to this
            many rows. NOTE(review): ``limit`` takes the first rows Spark
            returns, not a random sample — confirm that is acceptable.

    Returns:
        ``(psi, edges, cur_counts, ref_counts)``.
        - ``edges``: finite, strictly increasing quantile boundaries taken
          from the reference when present, otherwise from the current data.
        - ``cur_counts`` / ``ref_counts``: one int per bucket. The first and
          last buckets are open-ended, so values outside ``edges`` are counted
          in them instead of aborting the computation.
        - ``psi`` and ``ref_counts`` are None when there is no reference.
        All four are None when the edge source is empty, fewer than two
        finite edges exist, or an error occurs (logged).
    """
    try:
        def _numeric_sample(df: DataFrame) -> DataFrame:
            cleaned = (df
                       .select(F.regexp_replace(F.col(col), ",", "").cast(T.DoubleType()).alias(col))
                       .where(F.col(col).isNotNull()))
            if sample_size and sample_size > 0:
                cleaned = cleaned.limit(int(sample_size))
            return cleaned

        cur_num = _numeric_sample(cur_df)
        ref_num = _numeric_sample(ref_df) if ref_df is not None else None

        # Edges come from the reference when there is one, else from current.
        source_for_edges = ref_num if ref_num is not None else cur_num
        if source_for_edges.rdd.isEmpty():
            return None, None, None, None

        probs = [i / float(bins) for i in range(bins + 1)]
        quantiles = source_for_edges.approxQuantile(col, probs, 0.001)
        edges = [float(x) for x in quantiles if x is not None and math.isfinite(float(x))]
        if len(edges) < 2:
            return None, None, None, None

        # Quantiles of skewed data repeat; nudge repeats upward so the edges are
        # strictly increasing. The step scales with magnitude because a fixed
        # 1e-9 is below float resolution for large values and would be a no-op.
        strict = [edges[0]]
        for value in edges[1:]:
            prev = strict[-1]
            if value <= prev:
                value = prev + max(1e-9, abs(prev) * 1e-12)
            strict.append(value)
        edges = strict

        # Bucketizer treats values outside its splits as errors (handleInvalid
        # only covers NaN). Since the edges usually come from the reference, any
        # current value beyond the reference range used to fail the whole
        # metric — exactly the drift this check exists to detect. Open the
        # outer splits to +/-inf; the bucket count is unchanged.
        splits = [float("-inf")] + edges[1:-1] + [float("inf")]
        bucketizer = Bucketizer(splits=splits, inputCol=col, outputCol="_dq_bucket", handleInvalid="skip")

        n_bins = len(edges) - 1

        def _bucket_counts(df: DataFrame) -> List[int]:
            grouped = (bucketizer.transform(df)
                       .groupBy("_dq_bucket").count()
                       .withColumnRenamed("count", "cnt"))
            by_bucket = {int(r["_dq_bucket"]): int(r["cnt"]) for r in grouped.collect()}
            return [by_bucket.get(i, 0) for i in range(n_bins)]

        cur_counts = _bucket_counts(cur_num)

        ref_counts = None
        psi = None
        # A non-None reference is known to be non-empty here: it was the edge
        # source and passed the emptiness check above, so no second Spark job.
        if ref_num is not None:
            ref_counts = _bucket_counts(ref_num)
            psi = _compute_psi_from_counts(cur_counts, ref_counts, EPSILON)

        return psi, edges, cur_counts, ref_counts
    except Exception:
        logger.exception("compute_numeric_psi_and_buckets failed")
        return None, None, None, None


def compute_categorical_csi_and_buckets(
    col: str,
    cur_df: DataFrame,
    ref_df: Optional[DataFrame],
    sample_size: int = 50000,
    max_cardinality: int = 5000,
) -> Tuple[Optional[float], Optional[Dict[str, int]], Optional[Dict[str, int]]]:
    """Count categories on both sides and compute a chi-square style statistic.

    Values are cast to string with nulls mapped to the literal ``"NULL"``.
    When ``sample_size`` is positive each side is capped with ``limit`` first,
    and a map with more than ``max_cardinality`` categories is trimmed to the
    most frequent ones.

    Returns:
        ``(chi2_value, cur_map, ref_map)`` where the maps are
        ``{category: count}``. ``chi2_value`` is the sum over all categories
        of ``(cur - ref)^2 / (ref + EPSILON)`` on raw counts; it is None when
        either map is empty. On error (logged) all three are None.
    """
    try:
        def _category_counts(df: DataFrame) -> Dict[str, int]:
            if df is None:
                return {}
            as_text = F.coalesce(F.col(col).cast(T.StringType()), F.lit("NULL")).alias(col)
            frame = df.select(as_text)
            if sample_size and sample_size > 0:
                frame = frame.limit(int(sample_size))

            tallies = {}
            for row in frame.groupBy(col).count().collect():
                label = "NULL" if row[col] is None else row[col]
                tallies[label] = int(row["count"])

            # Guardrail: keep only the most frequent categories.
            if len(tallies) > max_cardinality:
                ranked = sorted(tallies.items(), key=lambda item: item[1], reverse=True)
                tallies = dict(ranked[:max_cardinality])
            return tallies

        cur_map = _category_counts(cur_df)
        ref_map = {} if ref_df is None else _category_counts(ref_df)

        if not (cur_map and ref_map):
            return None, cur_map, ref_map

        smoothing = EPSILON  # keeps the denominator non-zero for unseen reference categories
        chi2 = 0.0
        for label in sorted(set(cur_map) | set(ref_map)):
            observed = float(cur_map.get(label, 0))
            expected = float(ref_map.get(label, 0))
            delta = observed - expected
            chi2 += (delta * delta) / (expected + smoothing)

        return float(chi2), cur_map, ref_map
    except Exception:
        logger.exception("compute_categorical_csi_and_buckets failed")
        return None, None, None

def _tokenize_expr_for(colname: str) -> str:
    """Return the SQL expression that tokenizes ``colname``.

    The expression applies the configured ``PROTEGRITY_UDF`` to the trimmed,
    backtick-quoted column, with ``PROTEGRITY_POLICY`` as a quoted second
    argument.
    """
    template = "{}(trim(`{}`),'{}')"
    return template.format(PROTEGRITY_UDF, colname, PROTEGRITY_POLICY)

def tokenize_col_if_needed(df: DataFrame, raw_col_candidates: List[str], token_col_name: str = "harmonized_key") -> DataFrame:
    """Ensure ``df`` carries a ``token_col_name`` column.

    - If a column resolving to exactly ``token_col_name`` already exists, ``df``
      is returned unchanged.
    - Otherwise the first candidate in ``raw_col_candidates`` that resolves is
      tokenized via the Protegrity UDF (replacing any differently-cased
      existing token column).
    - If no candidate resolves, Protegrity is not configured, or the UDF
      expression fails, a null string column is added instead.
    """
    def _try_resolve(candidate):
        try:
            return resolve_col_name(df, candidate)
        except Exception:
            return None

    # Lazy scan: stops at the first candidate that resolves.
    resolved = next((hit for hit in map(_try_resolve, raw_col_candidates) if hit), None)

    existing_token_col = resolve_col_name(df, token_col_name)
    if existing_token_col and existing_token_col == token_col_name:
        return df

    def _with_null_token(frame):
        return frame.withColumn(token_col_name, F.lit(None).cast(T.StringType()))

    protegrity_available = bool(PROTEGRITY_UDF) and bool(PROTEGRITY_POLICY)
    if not (resolved and protegrity_available):
        if not resolved:
            logger.debug(f"No candidate column found among {raw_col_candidates} for tokenization.")
        if not protegrity_available:
            logger.debug("Protegrity not configured; adding null token column.")
        return _with_null_token(df)

    try:
        logger.info(f"Applying tokenization to column '{resolved}' -> '{token_col_name}' using {PROTEGRITY_UDF}.")
        if existing_token_col:
            df = df.drop(existing_token_col)
        return df.withColumn(token_col_name, F.expr(_tokenize_expr_for(resolved)))
    except Exception:
        logger.exception(f"Protegrity UDF failed for column '{resolved}'. Falling back to null '{token_col_name}'.")
        return _with_null_token(df)


# -----------------------
# Partitions / IO Helpers
def _parse_ymd_safe(s: str) -> Optional[date]:
    """Best-effort YYYY-MM-DD parser; returns None if parsing fails."""
    try:
        return datetime.strptime(s[:10], "%Y-%m-%d").date()
    except Exception:
        return None

In [0]:
# --------------------------------------------------------------------------------
# Load a single partition safely (typed literal; resolved columns)
# --------------------------------------------------------------------------------
def load_partition_df(
    spark_sess: SparkSession,
    full_table_path: str,
    partition_col: str,
    partition_value: str,
    columns: Optional[List[str]] = None,
    tbl_conf: Optional[Dict[str, Any]] = None,
) -> Tuple[Optional[DataFrame], str]:
    """Read one partition of a table, filtering with a correctly typed literal.

    Args:
        spark_sess: Active Spark session.
        full_table_path: Table identifier passed to ``spark.read.table``.
        partition_col: Partition column name (resolved against the schema).
        partition_value: Partition value as a string.
        columns: Optional projection; the partition column is always kept.
        tbl_conf: Optional table config; its ``date_columns`` entry supplies
            per-column date formats for the partition literal.

    Returns:
        ``(df, status)`` with status one of:
        - ``'loaded'``  : non-empty partition.
        - ``'missing'`` : the filter matched no rows; ``df`` is whatever
          ``_create_empty_df_with_schema`` returns.
        - ``'error'``   : partition column not found or any exception
          (logged); ``df`` is None.
    """
    try:
        table_df = spark_sess.read.table(full_table_path)
        schema = table_df.schema
        # Zero-row frame with the table schema, used only for name resolution.
        probe = spark_sess.createDataFrame([], schema=schema)

        pcol_resolved = resolve_col_name(probe, partition_col)
        if not pcol_resolved:
            logger.error(f"Partition column '{partition_col}' not found in '{full_table_path}'.")
            return None, "error"

        pfield = None
        for field in schema.fields:
            if field.name == pcol_resolved:
                pfield = field
                break
        if pfield is None:
            logger.error(f"Resolved partition column '{pcol_resolved}' missing in schema for '{full_table_path}'.")
            return None, "error"

        if isinstance(tbl_conf, dict):
            date_fmt_map = tbl_conf.get("date_columns")
        else:
            date_fmt_map = {}
        typed_literal = _format_partition_literal(partition_value, pfield.dataType, pcol_resolved, date_fmt_map)

        df = table_df.filter(F.col(pcol_resolved) == typed_literal)

        if columns:
            wanted = resolve_col_list(df, columns)
            if pcol_resolved not in wanted:
                wanted.append(pcol_resolved)
            # De-duplicate while preserving order.
            wanted = list(dict.fromkeys(wanted))
            if wanted:
                df = df.select(*[F.col(name) for name in wanted])

        if not df_is_empty(df):
            return df, "loaded"

        logger.warning(f"Partition '{pcol_resolved}={partition_value}' in '{full_table_path}' is empty.")
        return _create_empty_df_with_schema(spark_sess, full_table_path, columns), "missing"

    except Exception as e:
        logger.exception(f"Error loading {full_table_path} @ {partition_col}={partition_value}: {e}")
        return None, "error"



def is_numeric_column_by_schema(df: DataFrame, colname: str) -> bool:
    """Return True when the first schema field matching ``colname`` is numeric.

    The name comparison is case-insensitive. False is returned for a missing
    DataFrame, an empty column name, or a column that is not in the schema.
    """
    if df is None or not colname:
        return False
    matches = (
        isinstance(field.dataType, T.NumericType)
        for field in df.schema.fields
        if field.name.lower() == colname.lower()
    )
    return next(matches, False)

# ========================
# DQMonitor
# ========================
from typing import Any, Dict, List, Optional, Tuple, Union
from datetime import date, datetime, timedelta
from pyspark.sql import SparkSession, DataFrame, functions as F, types as T

class DQMonitor:
    """
    Core DQ monitor:
    - Caches loaded partitions and counts (single load per table/partition)
    - Emits normalized metrics via _emit_metric
    - Uses robust partition discovery and region fallback
    - Safe, vectorized checks (row-count, completeness, uniqueness, range, date/latency, join-consistency)
    """

    def __init__(
        self,
        spark_session: SparkSession,
        table_config: Optional[Dict[str, Dict[str, Any]]] = None,
        full_table_config: Optional[Dict[str, Dict[str, Any]]] = None,
        cross_conf: Optional[List[Dict[str, Any]]] = None,
        curr_date_obj: Optional[date] = None,
        static_thr: Optional[Dict[str, Union[float, int]]] = None,
        region: Optional[str] = None,
    ):
        """Set up configuration, run anchor date, and the per-run caches.

        Args:
            spark_session: Spark session used for all loads.
            table_config: Table configuration; takes precedence when given.
            full_table_config: Used only when ``table_config`` is None.
                NOTE(review): when both are passed this argument is ignored —
                confirm that is intended.
            cross_conf: Cross-system check definitions (defaults to ``[]``).
            curr_date_obj: Run anchor date (defaults to ``date.today()``).
            static_thr: Static thresholds (defaults to ``{}``).
            region: Region code; ``"sg"`` is used when falsy.

        When neither config is supplied, ``get_table_config(region)`` provides it.
        """
        self.spark = spark_session

        supplied_conf = table_config if table_config is not None else full_table_config
        if supplied_conf is None:
            supplied_conf = get_table_config(region)

        region_or_default = region or "sg"
        # Normalised twice on purpose so the two attributes never share one
        # mutable result object.
        self.table_config = validate_and_normalize_table_config(supplied_conf, region_or_default)
        self.full_table_config = validate_and_normalize_table_config(supplied_conf, region_or_default)

        self.cross_conf = cross_conf or []

        self.curr_date_obj = curr_date_obj or date.today()
        self.curr_date_str = self.curr_date_obj.strftime("%Y-%m-%d")

        self.static_thr = static_thr or {}
        self.region = region_or_default.lower()

        # Metrics accumulated during this run.
        self.all_metrics: List[Dict[str, Any]] = []

        # Cache keys follow the loader signature: (sys, tbl, partition).
        self.df_cache: Dict[Tuple[str, str, str], Optional[DataFrame]] = {}
        self.df_count_cache: Dict[Tuple[str, str, str], int] = {}

        self.ref_date_map: Dict[str, Optional[str]] = {}
        self.col_stats_cache: Dict[Tuple[str, str, str, str], Dict[str, Optional[float]]] = {}

        # Read by resolve_col_name(); without this attribute the first call
        # raises AttributeError. Keyed by the tuple of DataFrame column names.
        self._schema_lower_cache: Dict[Tuple[str, ...], Any] = {}

        self.metrics_history_df_cache: Optional[DataFrame] = None
        self.metrics_history_loaded = False

    # -------------------------------------------------------------------------
    # Metric helpers
    # -------------------------------------------------------------------------
    def _emit_metric(self, rec: Dict[str, Any], partition_date: Optional[str] = None):
        """Store one metric record in ``self.all_metrics`` with run metadata.

        ``run_id``, ``partition_date`` (falling back to the run's current date
        string), ``run_ts`` and ``country`` are filled in first; any key also
        present in ``rec`` overrides them.
        """
        record = {
            "run_id": RUN_ID,
            "partition_date": partition_date or self.curr_date_str,
            "run_ts": RUN_TS,
            "country": REGION,
        }
        for key, value in dict(rec).items():
            record[key] = value
        self.all_metrics.append(record)

    def _mark_uncomputable(
        self,
        metric_type: str,
        sys: str,
        tbl: str,
        col_name: Optional[str] = None,
        partition_date: Optional[str] = None,
        **kwargs,
    ):
        """Emit a failing ``UNCOMPUTABLE`` record for a metric that could not be evaluated.

        Extra keyword arguments are merged into the record and override the
        default fields of the same name.
        """
        defaults = {
            "metric_type": metric_type,
            "source_system": sys,
            "table_name": tbl,
            "column_name": col_name,
            "metric_value": "uncomputable",
            "metric_value_num": None,
            "status": "UNCOMPUTABLE",
            "alert_flag": "Fail",
        }
        self._emit_metric({**defaults, **kwargs}, partition_date=partition_date)

    # Instance wrapper so existing code can call self._normalize_metric_record(...)
    def _normalize_metric_record(self, rec: Dict[str, Any]) -> Dict[str, Any]:
        """Instance-level entry point that delegates to the module-level ``_normalize_metric_record``."""
        normalized = _normalize_metric_record(rec)
        return normalized

    def check_cross_system_join_consistency(self, *args, **kwargs):
        """Backward-compatible alias: forwards all arguments to ``check_cross_system_key_consistency``."""
        key_consistency_check = self.check_cross_system_key_consistency
        return key_consistency_check(*args, **kwargs)

    def resolve_col_name(self, df: DataFrame, name: Optional[str]) -> Optional[str]:
        if not name or df is None:
            return None
        key = tuple(df.columns)
        lm = self._schema_lower_cache.get(key)
        if lm is None:
            lm = _lower_map(df.columns)
            self._schema_lower_cache[key] = lm
        return lm.get(name.lower())

    def resolve_col_list(self, df: DataFrame, names: List[str]) -> List[str]:
        return [r for n in (names or []) if (r := self.resolve_col_name(df, n))]


    # ------------------------------------------------------------------------
    # Config / partitions
    # -------------------------------------------------------------------------
    def _get_table_conf(self, sys: str, tbl: str) -> Optional[Dict[str, Any]]:
        """Safe accessor for a (system, table) config; injects system default partition col if missing."""
        cfg_root = self.table_config or self.full_table_config or {}
        sys_conf = (cfg_root.get(sys) or {})
        tables = sys_conf.get("tables") or {}
        tconf = tables.get(tbl)
        if not isinstance(tconf, dict):
            return None
        if "default_partition_col" not in tconf:
            tconf = dict(tconf)  # shallow copy
            tconf["default_partition_col"] = sys_conf.get("default_partition_col", "ods")
        return tconf

    def _get_ref_partition_date(self, full_table_path: str, partition_col: Optional[str] = None) -> Optional[str]:
        """Choose (and memoise per table path) a reference partition for drift baselines.

        Partitions are listed for ``full_table_path``; if none are found, the
        same path with the region token swapped to the other known region
        (``sg`` / ``hk``) is tried. The pick itself is delegated to
        ``choose_recent_ref_partition`` using the run anchor date and the
        ``max_ref_partition_age_days`` threshold (default ``MAX_REF_AGE_DAYS``).

        Returns:
            The chosen partition value, or None when no partitions exist
            (None is cached as well).
        """
        try:
            return self.ref_date_map[full_table_path]
        except KeyError:
            pass

        candidates = _show_partitions_safe(full_table_path, partition_col)

        if not candidates:
            # Best-effort fallback: look at the sibling region's table.
            try:
                home_region = (REGION or "sg").lower()
                for other_region in ["sg", "hk"]:
                    if other_region == home_region:
                        continue
                    swapped_path = _swap_region_token(full_table_path, home_region, other_region)
                    if swapped_path == full_table_path:
                        continue
                    candidates = _show_partitions_safe(swapped_path, partition_col)
                    if candidates:
                        break
            except Exception:
                logger.exception("Region fallback during ref partition lookup failed.")

        if not candidates:
            logger.warning(f"No partitions found for {full_table_path}; reference will be None.")
            self.ref_date_map[full_table_path] = None
            return None

        age_limit = int(self.static_thr.get("max_ref_partition_age_days", MAX_REF_AGE_DAYS))
        chosen = choose_recent_ref_partition(candidates, self.curr_date_obj, age_limit)
        self.ref_date_map[full_table_path] = chosen
        return chosen

    # -------------------------------------------------------------------------
    # Load & cache a single partition (with region fallback already handled in loader)
    # -------------------------------------------------------------------------
#     def _load_and_cache_df(
#     self,
#     sys: str,
#     tbl: str,
#     partition_date: str,
#     tbl_conf: Optional[Dict[str, Any]] = None,
# ) -> Optional[DataFrame]:
#         """
#         Load a single partition for (sys, tbl, partition_date) and cache both df and count.

#         Guardrail:
#           - If source system is EBBS and region is HK, skip loading entirely.
#         """
#         cache_key = (sys, tbl, partition_date)
#         if cache_key in self.df_cache:
#             return self.df_cache[cache_key]

#         # --- Skip EBBS for HK region ---
#         current_region = (getattr(self, "region", None) or DEFAULT_REGION).strip().lower()
#         if sys.upper() == "EBBS" and current_region == "hk":
#             logger.info(f"Skipping load for {sys}.{tbl}@{partition_date} (region={current_region}).")
#             self.df_cache[cache_key] = None
#             self.df_count_cache[cache_key] = 0
#             return None

#         tbl_conf = tbl_conf or self._get_table_conf(sys, tbl)
#         if not tbl_conf:
#             logger.error(f"Table config for {sys}.{tbl} not found.")
#             self.df_cache[cache_key] = None
#             self.df_count_cache[cache_key] = 0
#             return None

#         full_path = tbl_conf.get("full_table_path")
#         p_col = _system_partition_col(sys, tbl_conf)
#         cols_decl = tbl_conf.get("columns") or []
#         cols_to_load = list(cols_decl.keys()) if isinstance(cols_decl, dict) else list(cols_decl)

#         df, status = load_partition_df(
#             self.spark, full_path, p_col, partition_date, cols_to_load, tbl_conf=tbl_conf
#         )

#         # Region fallback if empty/missing/error AND a region hint is present
#         if (status in ("error", "missing") or df is None or df_is_empty(df)) and tbl_conf.get("region_hint"):
#             try:
#                 cur_region = (tbl_conf.get("region") or "sg").strip().lower()
#                 for alt_region in [r for r in ["sg", "hk"] if r != cur_region]:
#                     attempt_path = _swap_region_token(full_path, cur_region, alt_region)
#                     if attempt_path != full_path:
#                         logger.info(f"[{sys}.{tbl}] trying alt region path: {attempt_path}")
#                         df_alt, status_alt = load_partition_df(
#                             self.spark, attempt_path, p_col, partition_date, cols_to_load, tbl_conf=tbl_conf
#                         )
#                         if status_alt == "loaded" and df_alt is not None and not df_is_empty(df_alt):
#                             df, status, full_path = df_alt, "loaded", attempt_path
#                             break
#             except Exception:
#                 logger.exception("Region fallback error; continuing with primary load result.")

#         # Cache df + its count
#         self.df_cache[cache_key] = df
#         self.df_count_cache[cache_key] = 0 if (df is None or df_is_empty(df)) else df.count()

#         return df
    def _load_and_cache_df(
        self,
        sys: str,
        tbl: str,
        partition_date: str,
        tbl_conf: Optional[Dict[str, Any]] = None,
    ) -> Optional[DataFrame]:
        """Load the partition for ``(sys, tbl, partition_date)`` once and cache it.

        Steps:
          1. Return the cached frame if this key was already requested.
          2. Resolve the effective partition for the requested date via
             ``resolve_effective_partition_for_request`` (daily vs month-end).
          3. Load it with ``load_partition_df``; if the result is unusable and
             the table config carries ``region_hint``, retry on the
             region-swapped path.
          4. Cache the frame and its row count (0 for None/empty).

        Returns:
            The loaded DataFrame, or None when the config or the effective
            partition cannot be determined or the load produced nothing usable.
        """
        cache_key = (sys, tbl, partition_date)
        if cache_key in self.df_cache:
            return self.df_cache[cache_key]

        def _cache_nothing() -> None:
            self.df_cache[cache_key] = None
            self.df_count_cache[cache_key] = 0
            return None

        tbl_conf = tbl_conf or self._get_table_conf(sys, tbl)
        if not tbl_conf:
            logger.error(f"Table config for {sys}.{tbl} not found.")
            return _cache_nothing()

        full_path = tbl_conf.get("full_table_path")
        p_col = _system_partition_col(sys, tbl_conf)
        declared_cols = tbl_conf.get("columns") or []
        if isinstance(declared_cols, dict):
            cols_to_load = list(declared_cols.keys())
        else:
            cols_to_load = list(declared_cols)

        # Work out which physical partition serves the requested date.
        effective_partition = None
        try:
            effective_partition = resolve_effective_partition_for_request(
                full_table_path=full_path,
                partition_col=p_col,
                requested_date_str=partition_date,
                sys_name=sys,
                sys_conf=self.table_config.get(sys, {}),  # allows system-level overrides
                tbl_conf=tbl_conf,
                today=self.curr_date_obj,  # run anchor
            )
        except Exception:
            logger.exception("Failed to resolve effective partition for %s.%s @ %s", sys, tbl, partition_date)

        if not effective_partition:
            logger.warning("No effective partition found for %s.%s @ %s", sys, tbl, partition_date)
            return _cache_nothing()

        df, status = load_partition_df(
            self.spark, full_path, p_col, effective_partition, cols_to_load, tbl_conf=tbl_conf
        )

        # Region fallback: only when the primary load is unusable AND the table
        # is flagged with a region hint.
        primary_unusable = status in ("error", "missing") or df is None or df_is_empty(df)
        if primary_unusable and tbl_conf.get("region_hint"):
            try:
                cur_region = (tbl_conf.get("region") or "sg").strip().lower()
                for alt_region in ["sg", "hk"]:
                    if alt_region == cur_region:
                        continue
                    attempt_path = _swap_region_token(full_path, cur_region, alt_region)
                    if attempt_path == full_path:
                        continue
                    logger.info(f"[{sys}.{tbl}] trying alt region path: {attempt_path}")
                    df_alt, status_alt = load_partition_df(
                        self.spark, attempt_path, p_col, effective_partition, cols_to_load, tbl_conf=tbl_conf
                    )
                    if status_alt == "loaded" and df_alt is not None and not df_is_empty(df_alt):
                        df = df_alt
                        break
            except Exception:
                logger.exception("Region fallback error; continuing with primary load result.")

        # Remember both the frame and its row count for later checks.
        self.df_cache[cache_key] = df
        if df is None or df_is_empty(df):
            self.df_count_cache[cache_key] = 0
        else:
            self.df_count_cache[cache_key] = df.count()

        # Record which partition was actually used, for traceability.
        logger.info("Loaded %s.%s using effective partition %s (requested %s)", sys, tbl, effective_partition, partition_date)

        return df



    # -------------------------------------------------------------------------
    # Column stats (cached)
    # -------------------------------------------------------------------------
    def _get_column_stats(
        self,
        sys: str,
        tbl: str,
        partition_date: str,
        df: Optional[DataFrame],
        col_name: str
    ) -> Dict[str, Optional[float]]:
        """
        Return basic statistics for a single column of a loaded partition.

        The result always carries four keys: non_null_count, distinct_count,
        min and max. Computed results (including the all-None result stored
        after a failed aggregation) are kept in ``self.col_stats_cache`` under
        (sys, tbl, partition_date, resolved column). The early exits for a
        None/empty DataFrame or an unresolvable column return zero counts and
        are not cached.
        """
        def _template(count_value: Optional[float]) -> Dict[str, Optional[float]]:
            # Same key order for every result shape this method can return.
            return {"non_null_count": count_value, "distinct_count": count_value, "min": None, "max": None}

        if df is None or df_is_empty(df):
            return _template(0)

        resolved_col = resolve_col_name(df, col_name)
        if not resolved_col:
            logger.warning(f"Column '{col_name}' not found in DataFrame for stats.")
            return _template(0)

        cache_key = (sys, tbl, partition_date, resolved_col)
        if cache_key in self.col_stats_cache:
            return self.col_stats_cache[cache_key]

        computed = _template(None)
        try:
            # One aggregation pass yields all four statistics.
            summary = df.agg(
                F.sum(F.when(F.col(resolved_col).isNotNull(), 1).otherwise(0)).alias("non_null_count"),
                F.countDistinct(F.col(resolved_col)).alias("distinct_count"),
                F.min(F.col(resolved_col)).alias("minv"),
                F.max(F.col(resolved_col)).alias("maxv"),
            ).first()
            if summary:
                computed["non_null_count"] = float(summary["non_null_count"] or 0.0)
                computed["distinct_count"] = float(summary["distinct_count"] or 0.0)
                computed["min"] = _safe_float(summary["minv"])
                computed["max"] = _safe_float(summary["maxv"])
        except Exception:
            logger.exception(f"Column stats failed for {sys}.{tbl}@{partition_date}.{resolved_col}.")
            # Discard any partially filled values so callers see a uniform "unknown" result.
            computed = _template(None)

        self.col_stats_cache[cache_key] = computed
        return computed

    # -------------------------------------------------------------------------
    # Row count (current, drift, dynamic)
    # -------------------------------------------------------------------------
    def compute_rowcount_dynamic_threshold(
        self,
        sys: str,
        tbl: str,
        tbl_conf: Dict[str, Any],
        lookback_days: int,
    ) -> Optional[Dict[str, float]]:
        """
        Build a mean ± 3σ row-count band from recent partitions of a table.

        The window ends the day before ``CURR_DATE`` and spans ``lookback_days``
        days (at least one). Only partition values that ``_parse_ymd_safe`` can
        parse and that fall inside the window are counted.

        Args:
            sys: Source-system name (used for partition-column resolution and logs).
            tbl: Table name (used for logs).
            tbl_conf: Table configuration; ``full_table_path`` is read from it.
            lookback_days: Size of the history window in days.

        Returns:
            A dict with ``mean``, ``stddev`` (also exposed as ``std``), ``lower``,
            ``upper`` and ``n_samples``, or None when the path/partition column is
            missing, no partitions fall in the window, the query returns no rows,
            or the query fails (the failure is logged, not raised).
        """
        full_path = tbl_conf.get("full_table_path")
        pcol = _system_partition_col(sys, tbl_conf)
        if not full_path or not pcol:
            logger.warning(f"Missing path/partition_col for {sys}.{tbl}")
            return None

        parts = _show_partitions_safe(full_path, pcol)
        if not parts:
            logger.warning(f"No partitions for {sys}.{tbl}; dynamic threshold unavailable.")
            return None

        # NOTE(review): other checks anchor on self.curr_date_obj; this one uses the
        # module-level CURR_DATE — confirm both always refer to the same run date.
        end_date = CURR_DATE - timedelta(days=1)
        start_date = end_date - timedelta(days=max(1, lookback_days) - 1)
        window_parts = [p for p in parts if (d := _parse_ymd_safe(p)) and start_date <= d <= end_date]
        if not window_parts:
            logger.warning(f"No partitions within lookback window for {sys}.{tbl}.")
            return None

        part_vals = ",".join(f"'{v}'" for v in window_parts)
        try:
            counts_df = self.spark.sql(
                f"SELECT `{pcol}` as partition_value, COUNT(1) as cnt "
                f"FROM `{full_path}` WHERE `{pcol}` IN ({part_vals}) GROUP BY `{pcol}`"
            )
            rows = [float(r["cnt"]) for r in counts_df.collect()]
            if not rows:
                return None
            n = len(rows)
            mean_v = sum(rows) / n
            # Population variance; a single sample has no spread.
            var_v = (sum((x - mean_v) ** 2 for x in rows) / n) if n > 1 else 0.0
            std_v = math.sqrt(var_v)
            lower = max(0.0, mean_v - 3 * std_v)  # a row count cannot be negative
            upper = mean_v + 3 * std_v
            return {
                "mean": mean_v,
                "stddev": std_v,
                # Alias: check_row_count reads the deviation under "std".
                "std": std_v,
                "lower": lower,
                "upper": upper,
                "n_samples": n,
            }
        except Exception:
            logger.exception(f"Rowcount dynamic threshold failed for {sys}.{tbl}.")
            return None


    def check_row_count(
        self, sys: str, tbl: str, tbl_conf: Dict[str, Any], cur_df: Optional[DataFrame], partition_date: str
    ):
        """
        Emit the current row count, its drift versus a reference partition, and a dynamic-band check.

        Metrics produced for ``partition_date``:

        1. ``row_count``: the cached count for (sys, tbl, partition_date), always PASS.
        2. ``row_count`` (reference row) and ``row_count_drift``:
           |current - reference| / reference compared with
           ``static_thr['row_count_drift_pct']`` (default 0.01). Drift is marked
           uncomputable when there is no reference partition or it has no rows.
        3. ``row_count_dynamic_threshold``: whether the current count lies inside
           the band returned by ``compute_rowcount_dynamic_threshold`` over
           ``static_thr['historical_partition_window']`` days (default 30).

        Args:
            sys: Source-system name.
            tbl: Table name.
            tbl_conf: Table configuration (``full_table_path`` and ``partition_col`` are read).
            cur_df: DataFrame for the current partition; None marks ``row_count`` uncomputable.
            partition_date: Partition the metrics are attributed to.

        Returns:
            None. Results are reported via ``self._emit_metric`` / ``self._mark_uncomputable``.
        """
        if cur_df is None:
            logger.warning(f"Current DF is None for {sys}.{tbl}@{partition_date}.")
            self._mark_uncomputable("row_count", sys, tbl, partition_date=partition_date)
            return

        current_count = self.df_count_cache.get((sys, tbl, partition_date), 0)

        # (1) current row count
        self._emit_metric({
            "metric_type": "row_count",
            "source_system": sys,
            "table_name": tbl,
            "metric_value_num": float(current_count),
            "metric_value": str(current_count),
            "status": "PASS",
        }, partition_date=partition_date)

        # (2) drift vs recent reference
        ref_date = self._get_ref_partition_date(tbl_conf.get("full_table_path"), tbl_conf.get("partition_col"))
        ref_count = None
        if ref_date:
            ref_df = self._load_and_cache_df(sys, tbl, ref_date)
            if ref_df is not None:
                rc = self.df_count_cache.get((sys, tbl, ref_date), 0)
                if rc > 0:
                    ref_count = rc
                else:
                    logger.warning(f"Reference partition {ref_date} for {sys}.{tbl} has zero rows.")

            # NOTE(review): this is a second "row_count" row carrying only the reference
            # value — confirm downstream consumers expect two rows for this metric type.
            self._emit_metric({
                "metric_type": "row_count",
                "source_system": sys,
                "table_name": tbl,
                "reference_value": float(ref_count) if ref_count is not None else None,
                "status": "PASS" if ref_count is not None else "UNCOMPUTABLE",
            }, partition_date=partition_date)

            if ref_count:
                drift_abs = abs(current_count - ref_count)
                drift_pct = drift_abs / ref_count
                threshold_pct = float(self.static_thr.get("row_count_drift_pct", 0.01))
                status = "PASS" if drift_pct <= threshold_pct else "FAIL"

                self._emit_metric({
                    "metric_type": "row_count_drift",
                    "source_system": sys,
                    "table_name": tbl,
                    "metric_value_num": drift_pct,
                    "metric_value": f"{drift_pct:.2%}",
                    "threshold": f"<={threshold_pct:.2%}",
                    "reference_value": float(ref_count),
                    "status": status,
                    "alert_flag": "Fail" if status == "FAIL" else None,
                }, partition_date=partition_date)
            else:
                self._mark_uncomputable("row_count_drift", sys, tbl, partition_date=partition_date)
        else:
            logger.warning(f"No reference partition found for {sys}.{tbl}.")
            self._mark_uncomputable("row_count_drift", sys, tbl, partition_date=partition_date)

        # (3) dynamic band (mean±3σ) over history
        lookback_days = int(self.static_thr.get("historical_partition_window", 30))
        dyn = self.compute_rowcount_dynamic_threshold(sys, tbl, tbl_conf, lookback_days=lookback_days)
        if dyn:
            lower, upper, mean, n = dyn["lower"], dyn["upper"], dyn["mean"], dyn["n_samples"]
            # The band helper publishes the deviation as "stddev"; reading only "std"
            # raised KeyError on every successful band computation. Accept either key.
            std = dyn["stddev"] if "stddev" in dyn else dyn["std"]
            status = "PASS" if lower <= current_count <= upper else "FAIL"
            thr_str = f"mean={mean:.2f},std={std:.2f},lower={int(lower)},upper={int(upper)},n_samples={int(n)}"
            self._emit_metric({
                "metric_type": "row_count_dynamic_threshold",
                "source_system": sys,
                "table_name": tbl,
                "metric_value_num": float(current_count),
                "metric_value": f"{current_count}",
                "threshold": thr_str,
                "reference_value": mean,
                "status": status,
                "alert_flag": "Fail" if status == "FAIL" else None,
            }, partition_date=partition_date)
        else:
            self._mark_uncomputable("row_count_dynamic_threshold", sys, tbl, partition_date=partition_date)

    # -------------------------------------------------------------------------
    # Completeness
    # -------------------------------------------------------------------------
    def check_completeness(
        self, sys: str, tbl: str, tbl_conf: Dict[str, Any], cur_df: Optional[DataFrame], partition_date: str
    ):
        """
        Report the percentage of populated values per configured column, plus an overall figure.

        A value counts as missing when it is NULL or, once cast to string and
        trimmed, equals the empty string. Each column (and the "overall"
        aggregate across all checked columns) passes when its percentage is at
        least ``static_thr['completeness_pct']`` (default 99.0).

        Nothing is emitted when no configured column resolves against the
        DataFrame; a None/empty DataFrame or a failed aggregation is reported
        as uncomputable.
        """
        if cur_df is None or df_is_empty(cur_df):
            logger.warning(f"DF None/empty for {sys}.{tbl}@{partition_date}.")
            self._mark_uncomputable("completeness", sys, tbl, partition_date=partition_date)
            return

        row_total = self.df_count_cache.get((sys, tbl, partition_date), 0)
        if row_total == 0:
            # No rows recorded for this partition: report as trivially complete.
            self._emit_metric({
                "metric_type": "completeness",
                "source_system": sys,
                "table_name": tbl,
                "metric_value_num": 100.0,
                "metric_value": "100.00%",
                "status": "PASS",
            }, partition_date=partition_date)
            return

        columns = resolve_col_list(cur_df, tbl_conf.get("columns", []))
        if not columns:
            logger.warning(f"No columns configured for completeness in {sys}.{tbl}.")
            return

        def _missing_counter(name):
            # Counts rows where the column is NULL or blank after trimming.
            is_missing = F.col(name).isNull() | (F.trim(F.col(name).cast(T.StringType())) == '')
            return F.sum(F.when(is_missing, 1).otherwise(0)).alias(f"nulls__{name}")

        # All columns are measured in a single aggregation pass.
        counters = [_missing_counter(name) for name in columns]

        try:
            counts = cur_df.agg(*counters).first()
        except Exception:
            logger.exception(f"Completeness aggregation failed for {sys}.{tbl}@{partition_date}.")
            for name in columns:
                self._mark_uncomputable("completeness", sys, tbl, name, partition_date=partition_date)
            return

        min_pct = float(self.static_thr.get("completeness_pct", 99.0))

        def _report(label, pct_value):
            # Shared emitter for the per-column rows and the "overall" row.
            verdict = "PASS" if pct_value >= min_pct else "FAIL"
            self._emit_metric({
                "metric_type": "completeness",
                "source_system": sys,
                "table_name": tbl,
                "column_name": label,
                "metric_value_num": pct_value,
                "metric_value": f"{pct_value:.2f}%",
                "threshold": f">={min_pct:.2f}%",
                "status": verdict,
                "alert_flag": "Fail" if verdict == "FAIL" else None,
            }, partition_date=partition_date)

        missing_total = 0
        for name in columns:
            missing = int(counts[f"nulls__{name}"] or 0)
            missing_total += missing
            _report(name, (1.0 - (float(missing) / row_total)) * 100.0)

        # Overall completeness treats every (row, column) cell equally.
        _report("overall", (1.0 - (float(missing_total) / (row_total * len(columns)))) * 100.0)

    # -------------------------------------------------------------------------
    # Uniqueness
    # -------------------------------------------------------------------------
    def check_uniqueness(
        self, sys: str, tbl: str, tbl_conf: Dict[str, Any], cur_df: Optional[DataFrame], partition_date: str
    ):
        """
        Report distinct key combinations as a percentage of all rows.

        The key is taken from ``tbl_conf['join_key']`` (a single name or a list
        of names). The check passes when the percentage is at least
        ``static_thr['uniqueness_pct']`` (default 99.9). Nothing is emitted when
        no join key is configured; unresolvable keys, a None/empty DataFrame,
        or a failed count are reported as uncomputable.
        """
        if cur_df is None or df_is_empty(cur_df):
            logger.warning(f"DF None/empty for {sys}.{tbl}@{partition_date}.")
            self._mark_uncomputable("uniqueness", sys, tbl, partition_date=partition_date)
            return

        row_total = self.df_count_cache.get((sys, tbl, partition_date), 0)
        if row_total == 0:
            # No rows recorded for this partition: report as trivially unique.
            self._emit_metric({
                "metric_type": "uniqueness",
                "source_system": sys,
                "table_name": tbl,
                "metric_value_num": 100.0,
                "metric_value": "100.00%",
                "status": "PASS",
            }, partition_date=partition_date)
            return

        join_key_config = tbl_conf.get("join_key")
        if not join_key_config:
            logger.warning(f"No 'join_key' specified in {sys}.{tbl}.")
            return

        # Normalise the configuration to a list of resolved column names.
        if not isinstance(join_key_config, list):
            single = resolve_col_name(cur_df, join_key_config)
            key_columns = [single] if single else []
        else:
            key_columns = resolve_col_list(cur_df, join_key_config)

        if not key_columns:
            logger.error(f"None of keys {join_key_config} found in {sys}.{tbl}.")
            self._mark_uncomputable("uniqueness", sys, tbl, partition_date=partition_date)
            return

        try:
            key_frame = cur_df.select(*[F.col(name) for name in key_columns])
            unique_rows = key_frame.distinct().count()
            if row_total > 0:
                pct = (float(unique_rows) / row_total) * 100.0
            else:
                pct = 100.0
            min_pct = float(self.static_thr.get("uniqueness_pct", 99.9))
            verdict = "PASS" if pct >= min_pct else "FAIL"
            self._emit_metric({
                "metric_type": "uniqueness",
                "source_system": sys,
                "table_name": tbl,
                "column_name": ",".join(key_columns),
                "metric_value_num": pct,
                "metric_value": f"{pct:.2f}%",
                "threshold": f">={min_pct:.2f}%",
                "status": verdict,
                "alert_flag": "Fail" if verdict == "FAIL" else None,
            }, partition_date=partition_date)
        except Exception:
            logger.exception(f"Uniqueness check failed for {sys}.{tbl}@{partition_date}.")
            self._mark_uncomputable("uniqueness", sys, tbl, ",".join(key_columns), partition_date=partition_date)

    # -------------------------------------------------------------------------
    # Range check (vectorized)
    # -------------------------------------------------------------------------
    def check_range(
        self,
        sys: str,
        tbl: str,
        tbl_conf: Dict[str, Any],
        cur_df: Optional[DataFrame],
        partition_date: str
    ):
        """
        Report, per configured numeric column, the percentage of out-of-range values.

        Bounds come from ``tbl_conf['numerical_ranges']`` ({column: {"min", "max"}});
        a missing bound is treated as unbounded on that side. A non-null value
        that cannot be cast to double is counted as out of range. Each column
        passes when its percentage does not exceed
        ``static_thr['max_out_of_range_pct']`` (default 0.0).

        Columns that cannot be resolved or whose bounds are invalid are marked
        uncomputable individually; the remaining columns are measured in one
        aggregation pass.
        """
        if cur_df is None or df_is_empty(cur_df):
            logger.warning(f"DF None/empty for {sys}.{tbl}@{partition_date}.")
            self._mark_uncomputable("range_check", sys, tbl, partition_date=partition_date)
            return

        count_key = (sys, tbl, partition_date)
        total = self.df_count_cache.get(count_key)
        if total is None:
            # Count on demand when the loader did not record one, then remember it.
            try:
                total = cur_df.count()
            except Exception:
                logger.exception("Counting DF failed for range check; default total=0.")
                total = 0
            finally:
                self.df_count_cache[count_key] = total

        if total == 0:
            self._emit_metric({
                "metric_type": "range_check",
                "source_system": sys,
                "table_name": tbl,
                "metric_value_num": 0.0,
                "metric_value": "0.00%",
                "status": "PASS",
            }, partition_date=partition_date)
            return

        configured: Dict[str, Dict[str, Any]] = (tbl_conf.get("numerical_ranges") or {})
        if not configured:
            logger.info(f"No numerical ranges for {sys}.{tbl}.")
            return

        # Step 1: validate configuration -> (configured name, resolved name, min, max).
        usable: List[Tuple[str, str, float, float]] = []
        for colname, bounds in configured.items():
            actual = resolve_col_name(cur_df, colname)
            if not actual:
                logger.warning(f"Range col '{colname}' not found in {sys}.{tbl}.")
                self._mark_uncomputable("range_check", sys, tbl, colname, partition_date=partition_date)
                continue
            try:
                low = float(bounds.get("min", float("-inf")))
                high = float(bounds.get("max", float("inf")))
            except Exception:
                logger.exception(f"Invalid bounds for {sys}.{tbl}.{colname}.")
                self._mark_uncomputable("range_check", sys, tbl, colname, partition_date=partition_date)
                continue
            usable.append((colname, actual, low, high))

        if not usable:
            return

        # Step 2: one violation counter per column.
        counters = []
        plan = []
        for configured_name, actual, low, high in usable:
            raw = F.col(actual)
            as_double = F.when(raw.isNotNull(), raw.cast(T.DoubleType())).otherwise(F.lit(None))
            not_numeric = raw.isNotNull() & as_double.isNull()
            outside = as_double.isNotNull() & ((as_double < F.lit(low)) | (as_double > F.lit(high)))
            violation = F.when(not_numeric, 1).when(outside, 1).otherwise(0)
            counter_name = f"__oor__{actual}"
            counters.append(F.sum(violation).alias(counter_name))
            plan.append((configured_name, counter_name, low, high))

        # Step 3: evaluate all counters in a single pass.
        try:
            result = cur_df.agg(*counters).first()
        except Exception:
            logger.exception(f"Range aggregation failed for {sys}.{tbl}@{partition_date}.")
            for configured_name, _, _, _ in plan:
                self._mark_uncomputable("range_check", sys, tbl, configured_name, partition_date=partition_date)
            return

        # Step 4: emit one metric per column under its configured name.
        max_allowed_pct = float(self.static_thr.get("max_out_of_range_pct", 0.0))
        for orig_col, counter_name, mn, mx in plan:
            try:
                violations = float(result[counter_name] or 0.0)
                if total > 0:
                    pct = (violations / float(total)) * 100.0
                else:
                    pct = 0.0
                verdict = "PASS" if pct <= max_allowed_pct else "FAIL"
                self._emit_metric({
                    "metric_type": "range_check",
                    "source_system": sys,
                    "table_name": tbl,
                    "column_name": orig_col,
                    "metric_value_num": pct,
                    "metric_value": f"{pct:.2f}%",
                    "threshold": f"max_out_of_range_pct<={max_allowed_pct:.2f}%",
                    "status": verdict,
                    "details": f"[{mn}, {mx}]",
                    "alert_flag": "Fail" if verdict == "FAIL" else None,
                }, partition_date=partition_date)
            except Exception:
                logger.exception(f"Emit range metric failed for {sys}.{tbl}.{orig_col}@{partition_date}.")
                self._mark_uncomputable("range_check", sys, tbl, orig_col, partition_date=partition_date)

    # -------------------------------------------------------------------------
    # Date parsing validity + Latency
    # -------------------------------------------------------------------------
    def check_date_logic(
        self,
        sys: str,
        tbl: str,
        tbl_conf: Dict[str, Any],
        cur_df: Optional[DataFrame],
        partition_date: str
    ):
        """
        Run the two date-related checks for a table partition.

        1. ``date_format_validity``: for each entry of ``tbl_conf['date_columns']``
           ({column: format}), the share of rows whose value parses with that
           format; passes at or above ``static_thr['date_parse_min_pct']``
           (default 95.0).
        2. ``data_latency``: age in days of the maximum value of
           ``tbl_conf['timestamp_col']`` relative to ``self.curr_date_obj``;
           passes at or below ``static_thr['latency_days']`` (default 1).

        Either part is skipped (with an info log) when its configuration is absent.
        """
        if cur_df is None or df_is_empty(cur_df):
            logger.warning(f"DF None/empty for {sys}.{tbl}@{partition_date}.")
            for metric_name in ("date_format_validity", "data_latency"):
                self._mark_uncomputable(metric_name, sys, tbl, partition_date=partition_date)
            return

        row_total = self.df_count_cache.get((sys, tbl, partition_date), 0)
        if row_total == 0:
            # No rows recorded for this partition: both checks pass trivially.
            self._emit_metric({"metric_type": "date_format_validity", "source_system": sys, "table_name": tbl,
                               "metric_value_num": 100.0, "metric_value": "100.00%", "status": "PASS"}, partition_date)
            self._emit_metric({"metric_type": "data_latency", "source_system": sys, "table_name": tbl,
                               "metric_value_num": 0.0, "metric_value": "0 days", "status": "PASS"}, partition_date)
            return

        # ---- Part 1: date-format validity ---------------------------------
        formats_by_column: Dict[str, str] = (tbl_conf.get("date_columns") or {})
        if not formats_by_column:
            logger.info(f"No date columns configured for {sys}.{tbl}; skipping date format validity.")
        else:
            min_valid_pct = float(self.static_thr.get("date_parse_min_pct", 95.0))
            for dcol, fmt in formats_by_column.items():
                actual = resolve_col_name(cur_df, dcol)
                if not actual:
                    logger.warning(f"Date column '{dcol}' not found in {sys}.{tbl}.")
                    self._mark_uncomputable("date_format_validity", sys, tbl, dcol, partition_date=partition_date)
                    continue
                try:
                    # Count the rows that parse, without adding temporary columns.
                    parsed_expr = parse_to_timestamp_expr(actual, fmt)
                    parsable_flag = F.when(parsed_expr.isNotNull(), 1).otherwise(0)
                    parsable_rows = cur_df.select(F.sum(parsable_flag).alias("v")).first()["v"] or 0
                    pct_valid = (float(parsable_rows) / float(row_total)) * 100.0
                    verdict = "PASS" if pct_valid >= min_valid_pct else "FAIL"
                    self._emit_metric({
                        "metric_type": "date_format_validity",
                        "source_system": sys,
                        "table_name": tbl,
                        "column_name": dcol,
                        "metric_value_num": pct_valid,
                        "metric_value": f"{pct_valid:.2f}%",
                        "threshold": f">={min_valid_pct:.2f}%",
                        "status": verdict,
                        "alert_flag": "Fail" if verdict == "FAIL" else None,
                    }, partition_date=partition_date)
                except Exception:
                    logger.exception(f"Date-format validity failed for {sys}.{tbl}.{dcol}@{partition_date}.")
                    self._mark_uncomputable("date_format_validity", sys, tbl, dcol, partition_date=partition_date)

        # ---- Part 2: data latency -----------------------------------------
        ts_col_config = tbl_conf.get("timestamp_col")
        if not ts_col_config:
            logger.info(f"No timestamp column configured for latency in {sys}.{tbl}.")
            return

        resolved_ts = resolve_col_name(cur_df, ts_col_config)
        if not resolved_ts:
            logger.warning(f"Timestamp column '{ts_col_config}' not found in {sys}.{tbl}.")
            self._mark_uncomputable("data_latency", sys, tbl, ts_col_config, partition_date=partition_date)
            return

        try:
            # Native date/timestamp columns are aggregated directly; anything else
            # goes through a best-effort to_timestamp conversion first.
            column_type = {field.name: field.dataType for field in cur_df.schema.fields}.get(resolved_ts)
            if isinstance(column_type, (T.TimestampType, T.DateType)):
                newest_expr = F.max(F.col(resolved_ts))
            else:
                newest_expr = F.max(F.to_timestamp(F.col(resolved_ts)))
            newest = cur_df.select(newest_expr.alias("maxv")).first()["maxv"]

            if newest is None:
                self._mark_uncomputable("data_latency", sys, tbl, resolved_ts, partition_date=partition_date)
                return

            # datetime must be tested before date because datetime subclasses date.
            if isinstance(newest, datetime):
                age_days = (self.curr_date_obj - newest.date()).days
            elif isinstance(newest, date):
                age_days = (self.curr_date_obj - newest).days
            else:
                try:
                    as_datetime = datetime.fromisoformat(str(newest).replace("Z", "+00:00"))
                    age_days = (self.curr_date_obj - as_datetime.date()).days
                except Exception:
                    self._mark_uncomputable("data_latency", sys, tbl, resolved_ts, partition_date=partition_date)
                    return

            max_age_days = int(self.static_thr.get("latency_days", 1))
            verdict = "PASS" if age_days <= max_age_days else "FAIL"
            self._emit_metric({
                "metric_type": "data_latency",
                "source_system": sys,
                "table_name": tbl,
                "column_name": resolved_ts,
                "metric_value_num": float(age_days),
                "metric_value": f"{age_days} days",
                "threshold": f"<={max_age_days} days",
                "status": verdict,
                "alert_flag": "Fail" if verdict == "FAIL" else None,
            }, partition_date=partition_date)
        except Exception:
            logger.exception(f"Data latency failed for {sys}.{tbl}.{resolved_ts}@{partition_date}.")
            self._mark_uncomputable("data_latency", sys, tbl, resolved_ts, partition_date=partition_date)

    # -------------------------------------------------------------------------
    # Join consistency
    # -------------------------------------------------------------------------
    def check_join_consistency(
        self,
        sys: str,
        tbl: str,
        tbl_conf: Dict[str, Any],
        cur_df: Optional[DataFrame],
        partition_date: str,
        table_run_cache: Dict[Tuple[str, str], Optional[DataFrame]],
    ):
        """
        Measure, for every configured join, how many distinct left keys have no match on the right.

        Each entry of ``tbl_conf['joins']`` names the other system/table and the
        left/right key columns. The right-hand table is taken from
        ``table_run_cache`` or loaded for the same ``partition_date`` and then
        stored in that cache. ``static_thr['join_consistency_pct']`` (default
        95.0) is the required consistency; the allowed missing percentage is
        100 minus that value.

        Incomplete join entries are skipped with a warning; every other
        obstacle (empty frames, missing config, unresolved keys, failures) is
        reported as uncomputable for that join.
        """
        join_specs = tbl_conf.get("joins") or []
        if not join_specs:
            return

        for join_spec in join_specs:
            other_sys = join_spec.get("other_system")
            other_tbl = join_spec.get("other_table")
            left_key_config = join_spec.get("left_key")
            right_key_config = join_spec.get("right_key")

            def _give_up(left_key=left_key_config):
                # Report this join as uncomputable under its configured left key.
                self._mark_uncomputable("join_consistency", sys, tbl, left_key, partition_date=partition_date)

            if not (other_sys and other_tbl and left_key_config and right_key_config):
                logger.warning(f"Incomplete join configuration for {sys}.{tbl}: {join_spec}. Skipping.")
                continue

            if cur_df is None or df_is_empty(cur_df):
                logger.warning(f"DF None/empty for {sys}.{tbl}@{partition_date}.")
                _give_up()
                continue

            # Right-hand frame: reuse this run's cache, otherwise load and remember it.
            other_df = table_run_cache.get((other_sys, other_tbl))
            if other_df is None:
                other_conf = self._get_table_conf(other_sys, other_tbl)
                if not other_conf:
                    logger.error(f"Config for join target '{other_sys}.{other_tbl}' not found.")
                    _give_up()
                    continue
                other_df = self._load_and_cache_df(other_sys, other_tbl, partition_date, tbl_conf=other_conf)
                table_run_cache[(other_sys, other_tbl)] = other_df

            if other_df is None or df_is_empty(other_df):
                logger.warning(f"Target DF '{other_sys}.{other_tbl}'@{partition_date} missing/empty.")
                _give_up()
                continue

            left_col = resolve_col_name(cur_df, left_key_config)
            right_col = resolve_col_name(other_df, right_key_config)
            if not (left_col and right_col):
                logger.error(f"Join key resolution failed: {sys}.{tbl}.{left_key_config} or {other_sys}.{other_tbl}.{right_key_config}.")
                _give_up()
                continue

            try:
                left_rows = self.df_count_cache.get((sys, tbl, partition_date), 0)
                if left_rows == 0:
                    self._emit_metric({
                        "metric_type": "join_consistency",
                        "source_system": sys,
                        "table_name": tbl,
                        "column_name": left_key_config,
                        "metric_value_num": 0.0,
                        "metric_value": "0.00%",
                        "status": "PASS",
                    }, partition_date=partition_date)
                    continue

                # Work on distinct key values from each side.
                right_distinct = other_df.select(F.col(right_col).alias(right_col)).distinct()
                left_distinct = cur_df.select(F.col(left_col).alias(left_col)).distinct()

                # A left-anti join keeps the left keys that have no partner on the right.
                unmatched = left_distinct.join(
                    right_distinct,
                    left_distinct[left_col] == right_distinct[right_col],
                    "left_anti",
                )
                unmatched_count = unmatched.count()
                left_distinct_count = left_distinct.count()
                if left_distinct_count > 0:
                    pct_missing = (float(unmatched_count) / left_distinct_count) * 100.0
                else:
                    pct_missing = 0.0

                req_consistency = float(self.static_thr.get("join_consistency_pct", 95.0))
                allowed_missing_pct = 100.0 - req_consistency
                verdict = "PASS" if pct_missing <= allowed_missing_pct else "FAIL"

                self._emit_metric({
                    "metric_type": "join_consistency",
                    "source_system": sys,
                    "table_name": tbl,
                    "column_name": left_key_config,
                    "metric_value_num": pct_missing,
                    "metric_value": f"{pct_missing:.2f}%",
                    "threshold": f"allowed_missing_pct<={allowed_missing_pct:.2f}% (from {req_consistency:.2f}% consistency)",
                    "status": verdict,
                    "alert_flag": "Fail" if verdict == "FAIL" else None,
                    "reference_value": float(left_distinct_count),
                }, partition_date=partition_date)
            except Exception:
                logger.exception(
                    f"Join consistency failed for {sys}.{tbl} ({left_key_config}) vs {other_sys}.{other_tbl} ({right_key_config})@{partition_date}."
                )
                _give_up()

    # -------------------------------------------------------------------------
    # Cross-system key consistency
    # -------------------------------------------------------------------------
    def check_cross_system_key_consistency(
        self,
        mapping: Dict[str, Any],
        table_run_cache: Dict[Tuple[str, str], Optional[DataFrame]],
        partition_date: str
    ):
        """
        Verify that the distinct keys of a "from" table also exist in a "to" table.

        The mapping provides source_system_from / table_from / key_from and
        source_system_to / table_to / key_to, plus an optional boolean 'tokenize'.
        Each table is taken from table_run_cache when present; otherwise it is loaded
        with _load_and_cache_df and stored back into the cache.

        The 'cross_system_consistency' metric is the percentage of DISTINCT source keys
        that have no match in the target (left-anti join). It passes while that
        percentage is <= 100 - static_thr['cross_system_key_consistency_pct']
        (default 90%).

        An incomplete mapping is logged and skipped. Every other failure (missing table
        config, missing/empty data, unresolved key column, tokenization failure, or an
        exception during the computation) marks the metric uncomputable against the
        source table/key.
        """
        sf, st, sk_config = (
            mapping.get(field) for field in ("source_system_from", "table_from", "key_from")
        )
        tf, tt, tk_config = (
            mapping.get(field) for field in ("source_system_to", "table_to", "key_to")
        )
        tokenize = bool(mapping.get("tokenize", False))

        if not (sf and st and sk_config and tf and tt and tk_config):
            logger.warning(f"Incomplete cross-system mapping config: {mapping}. Skipping.")
            return

        def _give_up() -> None:
            # Uncomputable outcomes are always attributed to the source table/key.
            self._mark_uncomputable(
                "cross_system_consistency", sf, st, sk_config, partition_date=partition_date
            )

        def _obtain(system, table):
            """Return (config_found, dataframe), preferring the per-run cache."""
            frame = table_run_cache.get((system, table))
            if frame is not None:
                return True, frame
            conf = self._get_table_conf(system, table)
            if not conf:
                return False, None
            frame = self._load_and_cache_df(system, table, partition_date, tbl_conf=conf)
            table_run_cache[(system, table)] = frame
            return True, frame

        # Source side (A)
        conf_found, source_df = _obtain(sf, st)
        if not conf_found:
            logger.error(f"Config for source table '{sf}.{st}' not found.")
            _give_up()
            return

        # Target side (B)
        conf_found, target_df = _obtain(tf, tt)
        if not conf_found:
            logger.error(f"Config for target table '{tf}.{tt}' not found.")
            _give_up()
            return

        if any(frame is None or df_is_empty(frame) for frame in (source_df, target_df)):
            logger.warning(f"One/both DFs missing/empty for cross-system check: {sf}.{st}, {tf}.{tt} @ {partition_date}.")
            _give_up()
            return

        # Map the configured key names onto actual DataFrame columns
        source_key = resolve_col_name(source_df, sk_config)
        target_key = resolve_col_name(target_df, tk_config)
        if not (source_key and target_key):
            logger.error(f"Key resolution failed: '{sk_config}' in {sf}.{st} or '{tk_config}' in {tf}.{tt}.")
            _give_up()
            return

        if not tokenize:
            left_col, right_col = source_key, target_key
        else:
            # Tokenization is expected to add the harmonized key columns named below
            left_col, right_col = "_harm_key_a", "_harm_key_b"
            source_df = tokenize_col_if_needed(source_df, [source_key], token_col_name="_harm_key_a")
            target_df = tokenize_col_if_needed(target_df, [target_key], token_col_name="_harm_key_b")
            if not (left_col in source_df.columns and right_col in target_df.columns):
                logger.error("Tokenization failed to produce harmonized key columns.")
                _give_up()
                return

        # Fields shared by both emitted records
        base_record = {
            "metric_type": "cross_system_consistency",
            "source_system": sf,
            "table_name": st,
            "column_name": sk_config,
        }

        try:
            distinct_source = source_df.select(F.col(left_col).alias(left_col)).distinct()
            source_total = distinct_source.count()
            if source_total == 0:
                # No source keys at all: report 0% missing as a pass
                self._emit_metric({
                    **base_record,
                    "metric_value_num": 0.0,
                    "metric_value": "0.00%",
                    "status": "PASS",
                }, partition_date=partition_date)
                return

            distinct_target = target_df.select(F.col(right_col).alias(right_col)).distinct()

            # Anti-join keeps the source keys that have no counterpart in the target
            unmatched = distinct_source.join(
                distinct_target,
                distinct_source[left_col] == distinct_target[right_col],
                "left_anti",
            )
            unmatched_total = unmatched.count()
            pct_missing = (float(unmatched_total) / float(source_total)) * 100.0

            required_pct = float(self.static_thr.get("cross_system_key_consistency_pct", 90.0))
            allowed_missing = 100.0 - required_pct
            if pct_missing <= allowed_missing:
                status, alert = "PASS", None
            else:
                status, alert = "FAIL", "Fail"

            self._emit_metric({
                **base_record,
                "metric_value_num": pct_missing,
                "metric_value": f"{pct_missing:.2f}%",
                "threshold": f"allowed_missing_pct<={allowed_missing:.2f}% (from {required_pct:.2f}% consistency)",
                "status": status,
                "alert_flag": alert,
                "reference_value": float(source_total),
            }, partition_date=partition_date)

        except Exception:
            logger.exception(f"Cross-system key consistency failed for mapping {mapping} @ {partition_date}.")
            _give_up()

    # -------------------------------------------------------------------------
    # Distribution drift (numeric + categorical): PSI/CSI, Wasserstein distance ("Wazirstrian"), and Shannon-entropy ratio ("Genshin Shannon")
    # -------------------------------------------------------------------------
    def check_distribution_drift(
        self,
        sys: str,
        tbl: str,
        tbl_conf: Dict[str, Any],
        cur_df: Optional[DataFrame],
        partition_date: str
    ):
        """
        Emit drift metrics comparing the current partition with a reference partition.

        Numeric columns (tbl_conf['numeric_drift_cols'] plus every 'columns' entry whose
        role is 'numeric_feature') may produce:
          - numeric_drift_psi              (threshold 'numeric_psi_max', default 0.20)
          - numeric_drift_wazirstrian      (Wasserstein distance from _wasserstein_from_buckets;
                                            threshold 'wazirstrian_wasserstein_max', default 1.0)
          - numeric_drift_genshin_shannon  (entropy ratio H(cur) / H(ref) of the bucket counts;
                                            threshold 'genshin_shannon_max', default 1.15)

        Categorical columns (tbl_conf['categorical_drift_cols']; when that is empty, every
        'columns' entry that is a plain string or a dict whose role is not
        'numeric_feature', excluding names already treated as numeric) may produce:
          - categorical_drift_csi             (threshold 'categorical_csi_max', default 50.0)
          - categorical_drift_genshin_shannon (same entropy ratio on the category counts)

        A metric is emitted only when its value is available. When the entropy ratio
        could not be computed and there is no usable reference DataFrame, the ratio
        metric is marked uncomputable. Columns that cannot be resolved in cur_df are
        skipped (numeric ones with a warning).

        NOTE(review): a non-finite entropy ratio (reference entropy <= 0 while the current
        entropy is > 0) is neither emitted nor marked uncomputable - confirm this is intended.
        """
        if cur_df is None or df_is_empty(cur_df):
            logger.warning(f"DF None/empty for {sys}.{tbl}@{partition_date}. Skipping distribution drift.")
            return

        # Reference partition: located via config; an empty one is treated as absent
        ref_date = self._get_ref_partition_date(tbl_conf.get("full_table_path"), tbl_conf.get("partition_col"))
        ref_df = None
        if ref_date:
            candidate_ref = self._load_and_cache_df(sys, tbl, ref_date, tbl_conf=tbl_conf)
            if candidate_ref is None or df_is_empty(candidate_ref):
                logger.warning(f"Reference DF empty for {sys}.{tbl}@{ref_date}. Proceeding with PSI=None.")
            else:
                ref_df = candidate_ref

        # Alert thresholds and sampling knobs (static thresholds override module defaults)
        psi_max = float(self.static_thr.get("numeric_psi_max", 0.20))
        csi_max = float(self.static_thr.get("categorical_csi_max", 50.0))
        waz_max = float(self.static_thr.get("wazirstrian_wasserstein_max", 1.0))
        genshin_max = float(self.static_thr.get("genshin_shannon_max", 1.15))

        bins = int(self.static_thr.get("drift_bins", DRIFT_BINS))
        sample_size = int(self.static_thr.get("drift_sample_size", DRIFT_SAMPLE_SIZE))
        max_card = int(self.static_thr.get("max_categorical_cardinality", MAX_CATEGORICAL_CARDINALITY))

        def _entropy(counts) -> float:
            """Shannon entropy (natural log) of a collection of counts; 0.0 if they sum to <= 0."""
            counts = list(counts)
            total = float(sum(counts))
            if total <= 0:
                return 0.0
            acc = 0.0
            for n in counts:
                share = float(n) / total
                if share > 0:
                    acc -= share * math.log(share + EPSILON)
            return acc

        def _entropy_ratio(current_counts, reference_counts) -> float:
            """H(cur) / H(ref); inf when only the reference entropy is <= 0, 1.0 when both are."""
            h_cur = _entropy(current_counts)
            h_ref = _entropy(reference_counts)
            if h_ref > 0:
                return h_cur / max(EPSILON, h_ref)
            return float('inf') if h_cur > 0 else 1.0

        def _emit_drift(metric_type, column, value, value_spec, limit) -> None:
            """Emit one drift metric; it passes while value <= limit."""
            status = "PASS" if value <= limit else "FAIL"
            self._emit_metric({
                "metric_type": metric_type,
                "source_system": sys,
                "table_name": tbl,
                "column_name": column,
                "metric_value_num": float(value),
                "metric_value": format(value, value_spec),
                "threshold": f"<= {limit:.2f}",
                "status": status,
                "alert_flag": "Fail" if status == "FAIL" else None,
                "reference_value": ref_date,
            }, partition_date=partition_date)

        # --- Numeric drift ---
        num_cols_cfg = list(tbl_conf.get("numeric_drift_cols", []))
        # Columns declared with role == 'numeric_feature' are added to the explicit list
        for entry in (tbl_conf.get("columns") or []):
            if not (isinstance(entry, dict) and entry.get("role") == "numeric_feature"):
                continue
            name = entry.get("name")
            if name and name not in num_cols_cfg:
                num_cols_cfg.append(name)

        for colname in num_cols_cfg:
            resolved = resolve_col_name(cur_df, colname)
            if not resolved:
                logger.warning(f"Numeric drift column '{colname}' not in {sys}.{tbl}.")
                continue

            psi, edges, cur_counts, ref_counts = compute_numeric_psi_and_buckets(
                resolved, cur_df, ref_df, bins=bins, sample_size=sample_size
            )

            # Distance and entropy ratio need both histograms
            waz = genshin = None
            if edges and cur_counts is not None and ref_counts is not None:
                waz = _wasserstein_from_buckets(edges, cur_counts, ref_counts)
                genshin = _entropy_ratio(cur_counts, ref_counts)

            if psi is not None:
                _emit_drift("numeric_drift_psi", colname, psi, ".4f", psi_max)

            if waz is not None:
                _emit_drift("numeric_drift_wazirstrian", colname, waz, ".6f", waz_max)

            if genshin is None:
                if ref_df is None:
                    # No reference -> explicitly mark the ratio-based metric uncomputable
                    self._mark_uncomputable("numeric_drift_genshin_shannon", sys, tbl, colname, partition_date=partition_date)
            elif math.isfinite(genshin):
                _emit_drift("numeric_drift_genshin_shannon", colname, genshin, ".4f", genshin_max)

        # --- Categorical drift ---
        # An explicit list wins; otherwise candidates are inferred from 'columns'
        cat_cols_cfg = list(tbl_conf.get("categorical_drift_cols", []))
        if not cat_cols_cfg:
            for entry in (tbl_conf.get("columns") or []):
                if isinstance(entry, str):
                    candidate = entry
                elif isinstance(entry, dict):
                    candidate = entry.get("name")
                    if not candidate or entry.get("role") == "numeric_feature":
                        continue
                else:
                    continue
                if candidate not in num_cols_cfg and candidate not in cat_cols_cfg:
                    cat_cols_cfg.append(candidate)

        for colname in cat_cols_cfg:
            resolved = resolve_col_name(cur_df, colname)
            if not resolved:
                continue

            csi, cur_map, ref_map = compute_categorical_csi_and_buckets(
                resolved, cur_df, ref_df, sample_size=sample_size, max_cardinality=max_card
            )

            if csi is not None:
                _emit_drift("categorical_drift_csi", colname, csi, ".2f", csi_max)

            # Entropy ratio over the per-category counts
            genshin = None
            if cur_map and ref_map:
                genshin = _entropy_ratio(cur_map.values(), ref_map.values())

            if genshin is None:
                if ref_df is None:
                    self._mark_uncomputable("categorical_drift_genshin_shannon", sys, tbl, colname, partition_date=partition_date)
            elif math.isfinite(genshin):
                _emit_drift("categorical_drift_genshin_shannon", colname, genshin, ".4f", genshin_max)


    def finalize_and_write(self):
        """
        Sanitize -> normalize -> write the three canonical outputs, then release caches.

        Outputs (each appended with write_dataset, each attempted independently):
          1) DQ_METRIC_HISTORY    - every normalized metric of the run (full audit)
          2) DQ_MONITORING_REPORT - rows with status == "FAIL" or alert_flag == "Fail"
          3) DQ_INTERNAL_ERRORS   - explicit error metric types plus UNCOMPUTABLE/ERROR rows

        Exceptions are logged rather than propagated. Cached DataFrames are released in
        a `finally` block, so they are unpersisted even when there is nothing to write
        or an unexpected error occurs. The in-memory metrics buffer is cleared only
        after the write steps have been attempted; on the early-abort and error paths
        it is left untouched.
        """
        logger.info("Sanitizing metrics and preparing outputs (finalize_and_write).")

        try:
            # 1) Sanitize collected metrics
            try:
                sanitized_metrics = sanitize_all_metrics_return_new(self.all_metrics or [])
            except Exception:
                logger.exception("sanitize_all_metrics_return_new failed; falling back to raw metrics.")
                sanitized_metrics = list(self.all_metrics or [])

            # 2) Normalize sanitized metrics
            try:
                normalized_records = [self._normalize_metric_record(rec) for rec in sanitized_metrics]
            except Exception:
                logger.exception("Failed to normalize sanitized metrics; using sanitized output as-is.")
                normalized_records = sanitized_metrics

            # 3) Build Spark DataFrame for history (canonical audit)
            metrics_history_df = rows_to_df(self.spark, normalized_records, schema=metrics_history_schema)
            if metrics_history_df is None or df_is_empty(metrics_history_df):
                logger.warning("metrics_history_df is None/empty after rows_to_df. Aborting writes.")
                return  # the `finally` below still releases the cached DataFrames

            # 4) Write full metrics history (audit)
            try:
                history_ds = OUTPUT_DATASETS.get("metrics_history_df", "DQ_METRIC_HISTORY")
                write_dataset(history_ds, metrics_history_df, mode="append")
                logger.info(f"Wrote full metrics history to '{history_ds}'. Rows: {len(normalized_records)}")
            except Exception:
                logger.exception("Failed to write metrics history dataset.")

            # 5) Monitoring report: failed/alerted metrics
            try:
                monitoring_report_df = metrics_history_df.filter(
                    (F.col("status") == "FAIL") | (F.col("alert_flag") == "Fail")
                )
            except Exception:
                logger.exception("Error filtering metrics for monitoring report.")
                monitoring_report_df = None

            try:
                if monitoring_report_df is not None and not df_is_empty(monitoring_report_df):
                    mon_ds = OUTPUT_DATASETS.get("monitoring_report_df", "DQ_MONITORING_REPORT")
                    write_dataset(mon_ds, monitoring_report_df, mode="append")
                    logger.info(f"Monitoring report written to '{mon_ds}'.")
                else:
                    logger.info("No failed/alerted metrics to write to monitoring report.")
            except Exception:
                logger.exception("Failed to write monitoring report dataset.")

            # 6) Internal errors dataset: explicit error types or problematic status
            try:
                error_types = [
                    "metric_sanitization_failure",
                    "data_load_error",
                    "dq_internal_error",
                    "dq_distribution_internal_error",
                ]
                internal_errors_df = metrics_history_df.filter(
                    F.col("metric_type").isin(error_types)
                    | (F.col("status") == "UNCOMPUTABLE")
                    | (F.col("status") == "ERROR")
                )
            except Exception:
                logger.exception("Error filtering metrics for internal errors dataset.")
                internal_errors_df = None

            try:
                if internal_errors_df is not None and not df_is_empty(internal_errors_df):
                    err_ds = OUTPUT_DATASETS.get("internal_errors", "DQ_INTERNAL_ERRORS")
                    write_dataset(err_ds, internal_errors_df, mode="append")
                    logger.info(f"Internal errors dataset written to '{err_ds}'.")
                else:
                    logger.info("No internal errors to write.")
            except Exception:
                logger.exception("Failed to write internal errors dataset.")

            # 7) All writes attempted: clear the in-memory metrics buffer
            self.all_metrics = []

        except Exception as e:
            logger.exception(f"finalize_and_write encountered an unexpected exception: {e}")
        finally:
            # Runs on success, early abort, and error alike
            self._release_cached_dataframes()

    def _release_cached_dataframes(self):
        """Best-effort: unpersist and drop every cached DataFrame; failures are only logged."""
        logger.debug("Unpersisting cached DataFrames and clearing caches.")
        try:
            df_cache = getattr(self, "df_cache", None) or {}
            count_cache = getattr(self, "df_count_cache", None) or {}
            for key in list(df_cache):
                df = df_cache.pop(key, None)
                if df is not None:
                    try:
                        df.unpersist()
                        logger.debug(f"Unpersisted DataFrame for cache key: {key}")
                    except Exception as e:
                        logger.debug(f"Error unpersisting DataFrame for key {key}: {e}")
                # Drop the matching row-count entry, if any
                count_cache.pop(key, None)
        except Exception:
            logger.exception("Error during df_cache cleanup.")

        history_cache = getattr(self, "metrics_history_df_cache", None)
        if history_cache is not None:
            try:
                history_cache.unpersist()
                logger.debug("Unpersisted metrics history DataFrame cache.")
            except Exception as e:
                logger.debug(f"Error unpersisting metrics history DataFrame: {e}")
            self.metrics_history_df_cache = None




    def run_all(self):
        """
        Orchestrate all data quality checks for the current run date.

        Phases:
          1. Load the current partition of every configured table into a per-run cache;
             a failed load is recorded as a 'data_load_error' metric and cached as None.
          2. Run the per-table checks; an unhandled exception is recorded as a
             'dq_internal_error' metric for that table and the run continues.
          3. Run the cross-system key consistency check for every configured mapping;
             an unhandled exception is recorded as a
             'cross_system_consistency_internal_error' metric.
          4. finalize_and_write() - executed in a `finally` so metrics collected so far
             are still persisted if an earlier phase raises unexpectedly. Such an
             exception still propagates to the caller afterwards.
        """
        logger.info(f"Starting DQ Monitor run for date: {self.curr_date_str}")

        # cache for this run (current partition only)
        table_run_cache: Dict[Tuple[str, str], Optional[DataFrame]] = {}

        try:
            # --- Phase 1: Load current/effective partition data ---
            logger.info("Phase 1: Loading effective partitions for all configured tables.")
            for sys_name, sys_conf in (self.table_config or {}).items():
                for tbl_name, tbl_conf in (sys_conf.get("tables") or {}).items():
                    try:
                        logger.debug(f"Loading effective partition for {sys_name}.{tbl_name}@{self.curr_date_str}.")
                        cur_df = self._load_and_cache_df(sys_name, tbl_name, self.curr_date_str, tbl_conf=tbl_conf)
                        table_run_cache[(sys_name, tbl_name)] = cur_df
                    except Exception:
                        logger.exception("Load failed for %s.%s@%s", sys_name, tbl_name, self.curr_date_str)
                        table_run_cache[(sys_name, tbl_name)] = None
                        self._emit_metric({
                            "metric_type": "data_load_error",
                            "source_system": sys_name,
                            "table_name": tbl_name,
                            "metric_value": f"Load failure for {sys_name}.{tbl_name}@{self.curr_date_str}",
                            "status": "ERROR",
                            "alert_flag": "Fail"
                        }, partition_date=self.curr_date_str)

            # --- Phase 2: Run table-level checks ---
            logger.info("Phase 2: Running data quality checks.")
            for sys_name, sys_conf in (self.table_config or {}).items():
                for tbl_name, tbl_conf in (sys_conf.get("tables") or {}).items():
                    part = self.curr_date_str
                    logger.info(f"Running checks for {sys_name}.{tbl_name}@{part}")
                    cur_df = table_run_cache.get((sys_name, tbl_name))

                    try:
                        self.check_row_count(sys_name, tbl_name, tbl_conf, cur_df, part)
                        self.check_completeness(sys_name, tbl_name, tbl_conf, cur_df, part)
                        self.check_uniqueness(sys_name, tbl_name, tbl_conf, cur_df, part)
                        self.check_distribution_drift(sys_name, tbl_name, tbl_conf, cur_df, part)

                        # Optional checks: run only when the class defines them.
                        if hasattr(self, "check_range"):
                            self.check_range(sys_name, tbl_name, tbl_conf, cur_df, part)
                        if hasattr(self, "check_date_logic"):
                            self.check_date_logic(sys_name, tbl_name, tbl_conf, cur_df, part)

                        # Cross-system checks are deliberately NOT run per table; see Phase 3.

                    except Exception as e:
                        logger.exception("Checks failed for %s.%s@%s", sys_name, tbl_name, part)
                        self._emit_metric({
                            "metric_type": "dq_internal_error",
                            "source_system": sys_name,
                            "table_name": tbl_name,
                            "metric_value": f"Unhandled exception during checks: {e}",
                            "status": "UNCOMPUTABLE",
                            "alert_flag": "Fail"
                        }, partition_date=part)

            # --- Phase 3: Cross-system consistency checks ---
            logger.info("Phase 3: Running cross-system consistency checks.")
            for mapping in (self.cross_conf or []):
                try:
                    self.check_cross_system_key_consistency(mapping, table_run_cache, self.curr_date_str)
                except Exception as e:
                    logger.exception("Cross-system check failed for mapping: %s", mapping)
                    # A malformed (non-dict) mapping must not raise again inside this
                    # handler, otherwise the whole run would abort here.
                    info = mapping if isinstance(mapping, dict) else {}
                    self._emit_metric({
                        "metric_type": "cross_system_consistency_internal_error",
                        "source_system": info.get("source_system_from", "unknown"),
                        "table_name": info.get("table_from", "unknown"),
                        "column_name": info.get("key_from", "unknown"),
                        "metric_value": f"Unhandled exception during cross-system check: {e}",
                        "status": "UNCOMPUTABLE",
                        "alert_flag": "Fail"
                    }, partition_date=self.curr_date_str)

        finally:
            # --- Phase 4: Finalize and Write Results ---
            # Always persist whatever was collected, even after an unexpected failure.
            logger.info("Phase 4: Finalizing and writing results.")
            self.finalize_and_write()

        logger.info("DQ Monitor run complete.")

In [0]:
# -----------------------
# Main Execution Block
# -----------------------
def main():
    """
    Entry point for the DQ monitoring script.

    Loads the table configuration for REGION, builds a DQMonitor and runs all checks.
    If the configuration is empty the function logs an error and returns without
    running anything. If the monitor raises, a single 'dq_script_failure' record is
    appended (best effort) to the internal errors dataset; a failure to write that
    record is logged and swallowed, so this function does not propagate exceptions
    raised by the monitor.
    """
    logger.info("Starting DQ Monitor script execution.")

    table_config = get_table_config(REGION)
    if not table_config:
        logger.error("Failed to load table configuration. Exiting.")
        return

    try:
        monitor = DQMonitor(
            spark_session=spark,
            table_config=table_config,
            cross_conf=cross_system_key_harmonization,
            curr_date_obj=CURR_DATE,
            static_thr=dq_static_thresholds
        )
        monitor.run_all()

    except Exception as e:
        logger.exception("An unhandled exception occurred during DQMonitor execution.")
        # Best effort: leave one record of the failure in the internal errors dataset.
        try:
            error_metrics = [{
                "run_id": RUN_ID,
                "run_ts": RUN_TS,
                "partition_date": CURR_DATE_STR,
                "metric_type": "dq_script_failure",
                "source_system": "N/A",
                "table_name": "N/A",
                "column_name": "N/A",
                "metric_value": f"Script failed with unhandled exception: {e}",
                "status": "UNCOMPUTABLE",
                "alert_flag": "Fail",
                "country": REGION
            }]
            error_df = rows_to_df(spark, error_metrics, schema=metrics_history_schema)
            if error_df is None:
                logger.error("Could not build the final error metric DataFrame; nothing written.")
            else:
                # Same lookup-with-default as finalize_and_write, so a missing
                # 'internal_errors' key does not lose the failure record.
                err_ds = OUTPUT_DATASETS.get("internal_errors", "DQ_INTERNAL_ERRORS")
                write_dataset(err_ds, error_df, mode="append")
        except Exception:
            logger.exception("Failed to write final error metric to internal errors dataset.")

    finally:
        # The Spark session is intentionally left running here; call spark.stop()
        # only in environments that require an explicit stop.
        logger.info("DQ Monitor script finished.")


# Run the monitor only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()