In [0]:
# NOTE(review): `%pylab` runs `from pylab import *` and `from numpy import *`
# into this notebook namespace, which can shadow builtins such as `sum` for
# every later cell. IPython discourages `%pylab` in favour of
# `%matplotlib inline` plus explicit `import numpy as np` /
# `import matplotlib.pyplot as plt`. Before switching, confirm that no other
# cell depends on the star-imported names.
%pylab inline

In [0]:
def check_distribution_drift(self, sys: str, tbl: str, tbl_conf: Dict[str, Any], cur_df: Optional[DataFrame], partition_date: str):
        """
        Check numeric and categorical distribution drift for one table partition and emit metric rows.

        Numeric features (``tbl_conf["numeric_drift_cols"]``, strings or ``{"name": ...}`` dicts):
          - Primary metrics: bucket-based Wasserstein distance plus q10/q50/q90 quantile deltas.
          - Drift is flagged when Wasserstein >= the per-feature threshold
            (``tbl_conf["feature_thresholds"][col]``, else static ``numeric_psi_max`` as a placeholder).
            Otherwise drift is flagged when the relative median shift >= ``median_shift_pct``.
            The absolute median shift (threshold ``median_shift_abs``, defaulting to
            ``median_shift_pct``) is only consulted when the reference median is 0, because a
            relative shift is undefined in that case.
          - PSI (from ``compute_numeric_psi_and_buckets``) is emitted as a diagnostic.
        Categorical features are ``tbl_conf["categorical_cols"]`` plus every ``tbl_conf["columns"]``
        entry that is neither a numeric drift column nor a date column (case-insensitive, deduplicated):
          - Primary metrics: Jensen-Shannon divergence (top-50 keys + "__OTHER__"), entropy delta,
            top-K mass change and churn. These are only computed when a reference distribution exists.
          - Drift is flagged when JS >= ``categorical_js_max`` (default 0.10). JS uses the natural
            log, so it is at most about ln 2 and cannot share the chi2 threshold. Otherwise drift is
            flagged by the ``topk_mass_delta`` / ``topk_churn`` thresholds.
          - Chi2 (from ``compute_categorical_csi_and_buckets``) is emitted as a diagnostic against
            ``categorical_csi_max``.
        If no usable reference partition exists, the most recent ``psi_buckets`` / ``csi_buckets``
        snapshot for the same system/table/column from an EARLIER partition in the metrics history
        is used as the reference for both the diagnostic and the primary metrics.

        Args:
            sys: Source-system name (note: shadows the ``sys`` module inside this function).
            tbl: Table name.
            tbl_conf: Table configuration (keys listed above, plus ``full_table_path`` and ``country``).
            cur_df: Current partition data. ``None``/empty marks every configured feature uncomputable.
            partition_date: Partition being checked; written to every emitted row.

        Returns:
            None. Results go through ``self._emit_metric`` / ``self._mark_uncomputable``.
        """
        import itertools
        import math

        # ------------------------------------------------------------------
        # Pure helpers (no Spark access)
        # ------------------------------------------------------------------
        def _entry_name(entry):
            """Return the column name of a config entry (plain string or {"name": ...} dict), else None."""
            if isinstance(entry, str):
                return entry or None
            if isinstance(entry, dict):
                return entry.get("name") or None
            return None

        def _counts_to_prob_vector(counts, eps=EPSILON):
            """Normalise counts to probabilities clamped to >= eps (not renormalised); uniform if total <= 0."""
            total = float(sum(counts))
            if total <= 0:
                n = max(1, len(counts))
                return [1.0 / n for _ in counts]
            return [max(float(c) / total, eps) for c in counts]

        def _wasserstein_from_buckets(edges, cur_counts, ref_counts):
            """Approximate the 1-D Wasserstein distance between two histograms that share ``edges``.

            Integrates |CDF_cur - CDF_ref| bin by bin, treating each bin's mass as sitting at its left
            edge. Bins with a non-finite width (for example if the bucketing helper uses +/-inf outer
            edges) are skipped so the distance cannot become inf/nan. Returns None on missing or
            misaligned input.
            """
            try:
                if not edges or not cur_counts or not ref_counts:
                    return None
                n_bins = len(edges) - 1
                if len(cur_counts) != n_bins or len(ref_counts) != n_bins:
                    self.logger.warning("Bucket length mismatch for wasserstein.")
                    return None
                # Running sums of bin probabilities are each histogram's CDF at the bin boundaries (O(n)).
                cdf_cur = itertools.accumulate(_counts_to_prob_vector(cur_counts))
                cdf_ref = itertools.accumulate(_counts_to_prob_vector(ref_counts))
                distance = 0.0
                for left, right, f_cur, f_ref in zip(edges, edges[1:], cdf_cur, cdf_ref):
                    width = float(right) - float(left)
                    if not math.isfinite(width):
                        continue
                    distance += abs(f_cur - f_ref) * width
                return float(distance)
            except Exception:
                self.logger.exception("Wasserstein computation error.")
                return None

        def _jensen_shannon_from_maps(cur_map, ref_map, top_k=None, eps=EPSILON):
            """Jensen-Shannon divergence (natural log) between two {category: count} maps.

            If ``top_k`` is given, only the ``top_k`` keys by combined count are kept and the rest
            are pooled into "__OTHER__". Returns None on error.
            """
            try:
                cur_map = cur_map or {}
                ref_map = ref_map or {}
                all_keys = set(cur_map) | set(ref_map)
                if top_k and len(all_keys) > top_k:
                    combined = {key: cur_map.get(key, 0) + ref_map.get(key, 0) for key in all_keys}
                    top_keys = set(sorted(combined, key=combined.get, reverse=True)[:top_k])

                    def _pool_tail(m):
                        pooled = {key: v for key, v in m.items() if key in top_keys}
                        other = 0
                        for key, v in m.items():
                            if key not in top_keys:
                                other += v
                        if other > 0:
                            pooled["__OTHER__"] = other
                        return pooled

                    cur_map = _pool_tail(cur_map)
                    ref_map = _pool_tail(ref_map)
                    all_keys = set(cur_map) | set(ref_map)
                keys = sorted(all_keys)
                p = _counts_to_prob_vector([float(cur_map.get(key, 0)) for key in keys], eps)
                q = _counts_to_prob_vector([float(ref_map.get(key, 0)) for key in keys], eps)
                mid = [(a + b) / 2.0 for a, b in zip(p, q)]

                def _kl(xs, ys):
                    return math.fsum(x * math.log(x / y) for x, y in zip(xs, ys) if x > 0)

                return float(0.5 * (_kl(p, mid) + _kl(q, mid)))
            except Exception:
                self.logger.exception("Jensen-Shannon computation error.")
                return None

        def _entropy_from_map(m, eps=EPSILON):
            """Shannon entropy (natural log) of a {category: count} map; 0.0 for an empty/zero map."""
            counts = [float(v) for v in (m or {}).values()]
            total = sum(counts)
            if total <= 0:
                return 0.0
            probs = [max(c / total, eps) for c in counts]
            return float(-math.fsum(p * math.log(p) for p in probs))

        def _topk_churn_and_mass_change(cur_map, ref_map, k=10):
            """Compare the top-k categories of two count maps.

            Returns (mass_delta, removed, added, churn):
            - mass_delta is the current top-k share minus the reference top-k share;
            - removed / added count keys that left / entered the top-k;
            - churn is (removed + added) / k.
            """
            cur_map = cur_map or {}
            ref_map = ref_map or {}
            if k <= 0:
                return 0.0, 0, 0, 0.0
            top_cur = [key for key, _ in sorted(cur_map.items(), key=lambda kv: kv[1], reverse=True)[:k]]
            top_ref = [key for key, _ in sorted(ref_map.items(), key=lambda kv: kv[1], reverse=True)[:k]]
            removed = len(set(top_ref) - set(top_cur))
            added = len(set(top_cur) - set(top_ref))

            def _share(m, keys):
                total = math.fsum(m.values())
                if total == 0.0:
                    return 0.0
                return math.fsum(m.get(key, 0) for key in keys) / total

            mass_delta = _share(cur_map, top_cur) - _share(ref_map, top_ref)
            return float(mass_delta), removed, added, float((removed + added) / float(k))

        # ------------------------------------------------------------------
        # Preconditions
        # ------------------------------------------------------------------
        full_path = tbl_conf.get("full_table_path")
        if not full_path:
            self.logger.error(f"Missing 'full_table_path' for {sys}.{tbl}. Cannot perform drift checks.")
            return

        numeric_col_names = [n for n in (_entry_name(c) for c in (tbl_conf.get("numeric_drift_cols") or [])) if n]
        categorical_col_names = [n for n in (_entry_name(c) for c in (tbl_conf.get("categorical_cols") or [])) if n]

        if cur_df is None or df_is_empty(cur_df):
            self.logger.warning(f"Current DataFrame empty for {sys}.{tbl}@{partition_date}. Skipping distribution drift.")
            for c in numeric_col_names:
                self._mark_uncomputable("psi", sys, tbl, c, partition_date=partition_date)
            for c in categorical_col_names:
                self._mark_uncomputable("csi", sys, tbl, c, partition_date=partition_date)
            return

        # ------------------------------------------------------------------
        # Reference partition (optional)
        # ------------------------------------------------------------------
        ref_date = None
        try:
            ref_date = self._get_ref_partition_date(full_path)
        except Exception:
            self.logger.debug("No ref partition date available via _get_ref_partition_date.")
        ref_df = None
        if ref_date:
            try:
                ref_df = self._load_and_cache_df(sys, tbl, tbl_conf, ref_date)
            except Exception:
                self.logger.exception(f"Failed to load reference partition '{ref_date}' for {sys}.{tbl}; ignoring reference.")
                ref_df = None
            if ref_df is not None and df_is_empty(ref_df):
                self.logger.warning(f"Reference partition '{ref_date}' for {sys}.{tbl} is empty; ignoring reference.")
                ref_df = None

        # ------------------------------------------------------------------
        # Metrics history (loaded and cached once per instance)
        # ------------------------------------------------------------------
        if not self.metrics_history_loaded:
            mh_dataset_name = OUTPUT_DATASETS.get("metrics_history_df")
            if mh_dataset_name:
                try:
                    self.metrics_history_df_cache = spark.read.table(mh_dataset_name)
                    self.metrics_history_df_cache.cache()
                    try:
                        # Best effort: materialise the cache now; failure here is not fatal.
                        self.metrics_history_df_cache.count()
                    except Exception:
                        self.logger.debug(f"Could not materialise metrics history cache for '{mh_dataset_name}'.")
                except Exception:
                    self.logger.exception(f"Failed to load metrics history from '{mh_dataset_name}'.")
                    self.metrics_history_df_cache = None
            else:
                self.logger.debug("No metrics_history_df configured.")
            self.metrics_history_loaded = True

        mh_df = getattr(self, "metrics_history_df_cache", None)

        # ------------------------------------------------------------------
        # Thresholds (read once; invalid values fall back to the default)
        # ------------------------------------------------------------------
        static_thr = self.static_thr or {}

        def _cfg_number(key, default, cast=float):
            """Read a numeric static-config value, returning ``default`` if it is not convertible."""
            raw = static_thr.get(key, default)
            try:
                return cast(raw)
            except (TypeError, ValueError):
                self.logger.warning(f"Invalid static threshold {key}={raw!r}; using default {default!r}.")
                return default

        sample_size = _cfg_number("drift_sample_size", DRIFT_SAMPLE_SIZE, int)
        max_card = _cfg_number("max_categorical_cardinality", MAX_CATEGORICAL_CARDINALITY, int)
        psi_max = _cfg_number("numeric_psi_max", 0.20)
        med_rel_thresh = _cfg_number("median_shift_pct", 0.10)
        med_abs_thresh = _cfg_number("median_shift_abs", med_rel_thresh)
        csi_max = _cfg_number("categorical_csi_max", 50.0)
        js_thresh = _cfg_number("categorical_js_max", 0.10)
        topk_mass_thresh = _cfg_number("topk_mass_delta", 0.05)
        topk_churn_thresh = _cfg_number("topk_churn", 0.3)
        feature_thresholds = tbl_conf.get("feature_thresholds")
        if not isinstance(feature_thresholds, dict):
            feature_thresholds = {}
        top_k = 50

        # ------------------------------------------------------------------
        # Emission / history helpers
        # ------------------------------------------------------------------
        def _emit(metric_type, column, metric_value, metric_value_num=None, threshold=None,
                  reference_value=None, status="PASS", alert_flag=None):
            """Emit one metrics_history row for this system/table/partition."""
            self._emit_metric({
                "run_id": RUN_ID, "run_ts": RUN_TS, "partition_date": partition_date,
                "metric_type": metric_type, "source_system": sys, "table_name": tbl, "column_name": column,
                "metric_value": metric_value, "metric_value_num": metric_value_num, "threshold": threshold,
                "reference_value": reference_value, "status": status, "alert_flag": alert_flag,
                "country": tbl_conf.get("country", REGION),
            }, partition_date=partition_date)

        def _history_rows(metric_type, column):
            """Most recent history rows (up to MAX_LOOKBACK_SAMPLES) of ``metric_type`` for ``column``
            from partitions strictly before the one being checked."""
            return mh_df.filter(
                (F.col("source_system") == sys)
                & (F.col("table_name") == tbl)
                & (F.col("metric_type") == metric_type)
                & (F.col("column_name") == column)
                # A snapshot of this (or a later) partition is not a valid reference, e.g. on reruns.
                & (F.col("partition_date").cast("string") < str(partition_date))
            ).orderBy(F.col("partition_date").desc()).limit(MAX_LOOKBACK_SAMPLES).collect()

        def _numeric_history_fallback(resolved_col):
            """Recompute PSI against the newest usable historical psi_buckets snapshot.

            Returns (psi, edges, current_counts, reference_counts), with both count lists on the
            historical edges, or None if no snapshot could be used.
            """
            try:
                hist = _history_rows("psi_buckets", resolved_col)
            except Exception:
                self.logger.exception("Error loading historical metrics for PSI fallback.")
                return None
            for r in hist:
                try:
                    parsed = _parse_buckets_from_metric_value(r["metric_value"])
                    if not parsed:
                        continue
                    fb_edges = parsed.get("bucket_edges") or parsed.get("edges")
                    fb_ref_counts = parsed.get("bucket_counts") or parsed.get("counts")
                    if not fb_edges or not fb_ref_counts or len(fb_edges) - 1 != len(fb_ref_counts):
                        self.logger.warning(f"Skipping historical PSI record due to invalid bucket data: {r['partition_date']}")
                        continue
                    ref_counts_fb = [int(c) for c in fb_ref_counts]
                    bucketizer = Bucketizer(splits=[float(x) for x in fb_edges], inputCol=resolved_col, outputCol="_dq_bucket", handleInvalid="skip")
                    cur_proc = cur_df.select(
                        F.regexp_replace(F.col(resolved_col), ",", "").cast(DoubleType()).alias(resolved_col)
                    ).where(F.col(resolved_col).isNotNull())
                    cur_rows = bucketizer.transform(cur_proc).groupBy("_dq_bucket").count().withColumnRenamed("count", "cur_cnt").collect()
                    cur_counts_fb = [0] * (len(fb_edges) - 1)
                    for rr in cur_rows:
                        idx = int(rr["_dq_bucket"])
                        if 0 <= idx < len(cur_counts_fb):
                            cur_counts_fb[idx] = int(rr["cur_cnt"])
                    if "_compute_psi_from_counts" in dir(self):
                        psi_fb = self._compute_psi_from_counts(cur_counts_fb, ref_counts_fb, EPSILON)
                    else:
                        psi_fb = _compute_psi_from_counts(cur_counts_fb, ref_counts_fb, EPSILON)
                except Exception:
                    self.logger.exception("Error during PSI fallback computation with historical bins.")
                    continue
                if psi_fb is not None:
                    self.logger.info(f"Computed PSI using historical bins from {r['partition_date']}.")
                    return psi_fb, fb_edges, cur_counts_fb, ref_counts_fb
            return None

        def _categorical_history_fallback(resolved_col, cur_map):
            """Recompute chi2 against the newest usable historical csi_buckets snapshot.

            Returns (chi2, current_counts_map, reference_counts_map), or None if no snapshot could
            be used. If ``cur_map`` is None, current counts are computed once from ``cur_df``, with
            nulls as "NULL".
            """
            try:
                hist = _history_rows("csi_buckets", resolved_col)
            except Exception:
                self.logger.exception("Error loading historical metrics for CSI fallback.")
                return None
            current_counts = cur_map
            for r in hist:
                try:
                    parsed = _parse_buckets_from_metric_value(r["metric_value"])
                    if not parsed:
                        continue
                    fb_cats = parsed.get("categories")
                    fb_ref_counts = parsed.get("counts")
                    if not fb_cats or not fb_ref_counts or len(fb_cats) != len(fb_ref_counts):
                        self.logger.warning(f"Skipping historical CSI record due to invalid data: {r['partition_date']}")
                        continue
                    ref_counts_map = {("NULL" if cat is None else cat): int(cnt) for cat, cnt in zip(fb_cats, fb_ref_counts)}
                    if current_counts is None:
                        try:
                            rows = cur_df.select(
                                F.coalesce(F.col(resolved_col).cast(StringType()), F.lit("NULL")).alias(resolved_col)
                            ).groupBy(resolved_col).count().withColumnRenamed("count", "cur_cnt").collect()
                            current_counts = {row[resolved_col]: int(row["cur_cnt"]) for row in rows}
                        except Exception:
                            self.logger.exception("Failed to compute current category counts during CSI fallback.")
                            continue
                    if "_compute_chi2_from_maps" in dir(self):
                        chi2_fb = self._compute_chi2_from_maps(current_counts, ref_counts_map, EPSILON)
                    else:
                        chi2_fb = _compute_chi2_from_maps(current_counts, ref_counts_map, EPSILON)
                except Exception:
                    self.logger.exception("Error during CSI fallback using historical categories.")
                    continue
                if chi2_fb is not None:
                    self.logger.info(f"Computed CSI using historical categories from {r['partition_date']}.")
                    return chi2_fb, current_counts, ref_counts_map
            return None

        # ------------------------------------------------------------------
        # Per-column checks
        # ------------------------------------------------------------------
        def _check_numeric_column(col_name):
            """Compute, decide and emit numeric drift metrics for one configured column."""
            resolved_col = resolve_col_name(cur_df, col_name)
            if not resolved_col:
                self.logger.warning(f"Numeric column '{col_name}' not found in {sys}.{tbl}.")
                self._mark_uncomputable("psi", sys, tbl, col_name, partition_date=partition_date)
                return

            psi_val = psi_edges = cur_counts = ref_counts = None
            try:
                if "compute_numeric_psi_and_buckets" in globals():
                    psi_val, psi_edges, cur_counts, ref_counts = compute_numeric_psi_and_buckets(
                        resolved_col, cur_df, ref_df, bins=DRIFT_BINS, sample_size=sample_size
                    )
            except Exception:
                self.logger.exception(f"Error computing numeric buckets for {resolved_col}.")

            # Snapshot of the current distribution; later runs read these rows as a historical reference.
            if psi_edges is not None and cur_counts is not None:
                try:
                    _emit("psi_buckets", resolved_col,
                          json.dumps({"bucket_edges": psi_edges, "bucket_counts": cur_counts}, default=str, ensure_ascii=False))
                except Exception:
                    self.logger.exception("Failed to emit psi_buckets metric.")
            else:
                self._mark_uncomputable("psi_buckets", sys, tbl, resolved_col, partition_date=partition_date)

            if psi_val is None and ref_df is None and mh_df is not None:
                fallback = _numeric_history_fallback(resolved_col)
                if fallback is not None:
                    # Adopt the historical bins as the reference so Wasserstein can be computed below.
                    psi_val, psi_edges, cur_counts, ref_counts = fallback
                    try:
                        _emit("psi_buckets", resolved_col,
                              json.dumps({"bucket_edges": psi_edges, "bucket_counts": cur_counts}, default=str, ensure_ascii=False))
                    except Exception:
                        self.logger.exception("Failed to emit psi_buckets metric.")

            numeric_metrics = {}
            if psi_edges is not None and cur_counts is not None and ref_counts is not None:
                numeric_metrics["wasserstein"] = _wasserstein_from_buckets(psi_edges, cur_counts, ref_counts)
            if ref_df is not None:
                try:
                    probs = [0.1, 0.5, 0.9]
                    cur_q = cur_df.approxQuantile(resolved_col, probs, 0.01)
                    ref_q = ref_df.approxQuantile(resolved_col, probs, 0.01)
                    if len(cur_q) == len(probs) and len(ref_q) == len(probs):
                        numeric_metrics["q10_cur"], numeric_metrics["q50_cur"], numeric_metrics["q90_cur"] = cur_q
                        numeric_metrics["q10_ref"], numeric_metrics["q50_ref"], numeric_metrics["q90_ref"] = ref_q
                        shift = float(cur_q[1] - ref_q[1])
                        numeric_metrics["median_shift_abs"] = shift
                        numeric_metrics["median_shift_rel"] = shift / float(abs(ref_q[1])) if ref_q[1] != 0 else None
                except Exception:
                    self.logger.exception("Quantile diagnostics failed for numeric column.")

            # Wasserstein is in column units and PSI is unitless, so numeric_psi_max is only a
            # placeholder default. Prefer tuning tbl_conf["feature_thresholds"][col].
            w_val = numeric_metrics.get("wasserstein")
            w_thresh = psi_max
            per_feature = feature_thresholds.get(resolved_col)
            if per_feature is not None:
                try:
                    w_thresh = float(per_feature)
                except (TypeError, ValueError):
                    self.logger.warning(f"Invalid feature threshold for {resolved_col}: {per_feature!r}; using {psi_max}.")

            drift_flag, triggered_metric = False, None
            if w_val is not None and w_val >= w_thresh:
                drift_flag, triggered_metric = True, "wasserstein"
            else:
                med_rel = numeric_metrics.get("median_shift_rel")
                med_abs = numeric_metrics.get("median_shift_abs")
                if med_rel is not None:
                    if abs(med_rel) >= med_rel_thresh:
                        drift_flag, triggered_metric = True, "median_shift_rel"
                elif med_abs is not None and abs(med_abs) >= med_abs_thresh:
                    # Only reached when the reference median is 0 (relative shift undefined).
                    drift_flag, triggered_metric = True, "median_shift_abs"
            if drift_flag:
                self.logger.info(f"Numeric drift detected for {sys}.{tbl}.{resolved_col} (triggered by {triggered_metric}).")

            q50_ref = numeric_metrics.get("q50_ref")
            try:
                _emit("numeric_drift", resolved_col, str(drift_flag),
                      metric_value_num=float(w_val) if w_val is not None else None,
                      threshold=str(w_thresh),
                      reference_value=float(q50_ref) if q50_ref is not None else None,
                      status="ALERT" if drift_flag else "PASS",
                      alert_flag="Fail" if drift_flag else "Pass")
            except Exception:
                self.logger.exception("Failed to emit numeric_drift metric.")

            if psi_val is not None:
                try:
                    _emit("psi", resolved_col, f"{psi_val:.6f}",
                          metric_value_num=float(psi_val),
                          threshold=f"<={psi_max:.6f}",
                          status="PASS" if psi_val <= psi_max else "FAIL",
                          alert_flag="Fail" if psi_val > psi_max else None)
                except Exception:
                    self.logger.exception("Failed to emit PSI metric.")

        def _check_categorical_column(col_name):
            """Compute, decide and emit categorical drift metrics for one column."""
            resolved_col = resolve_col_name(cur_df, col_name)
            if not resolved_col:
                self.logger.warning(f"Categorical column '{col_name}' not found in {sys}.{tbl}.")
                self._mark_uncomputable("csi", sys, tbl, col_name, partition_date=partition_date)
                return

            chi2_val = cur_map = ref_map = None
            try:
                if "compute_categorical_csi_and_buckets" in globals():
                    chi2_val, cur_map, ref_map = compute_categorical_csi_and_buckets(
                        resolved_col, cur_df, ref_df, sample_size=sample_size, max_cardinality=max_card
                    )
            except Exception:
                self.logger.exception(f"Error computing categorical buckets for {resolved_col}.")

            if cur_map is not None:
                try:
                    _emit("csi_buckets", resolved_col,
                          json.dumps({"categories": list(cur_map.keys()), "counts": list(cur_map.values())}, default=str, ensure_ascii=False))
                except Exception:
                    self.logger.exception("Failed to emit csi_buckets metric.")
            else:
                self._mark_uncomputable("csi_buckets", sys, tbl, resolved_col, partition_date=partition_date)

            if chi2_val is None and ref_df is None and mh_df is not None:
                fallback = _categorical_history_fallback(resolved_col, cur_map)
                if fallback is not None:
                    # Adopt the historical categories as the reference for JS/entropy/top-k below.
                    chi2_val, cur_map, ref_map = fallback
                    try:
                        aligned_cats = sorted(set(cur_map) | set(ref_map))
                        _emit("csi_buckets", resolved_col,
                              json.dumps({"categories": aligned_cats, "counts": [cur_map.get(c, 0) for c in aligned_cats]},
                                         default=str, ensure_ascii=False))
                    except Exception:
                        self.logger.exception("Failed to emit csi_buckets metric.")

            cur_map = cur_map or {}
            ref_map = ref_map or {}
            cat_metrics = {}
            if not ref_map:
                # Without a reference, the top-k share is compared against an empty map, so
                # mass_delta equals the current top-k share and drift would be flagged on nearly
                # every run. Skip the comparison instead.
                self.logger.debug(f"No reference distribution for {sys}.{tbl}.{resolved_col}; skipping categorical drift metrics.")
            else:
                try:
                    cardinality = len(set(cur_map) | set(ref_map))
                    if cardinality <= max_card:
                        if "_jensen_shannon_from_maps" in dir(self):
                            cat_metrics["js"] = self._jensen_shannon_from_maps(cur_map, ref_map, top_k=top_k)
                        else:
                            cat_metrics["js"] = _jensen_shannon_from_maps(cur_map, ref_map, top_k=top_k)
                        ent_cur = _entropy_from_map(cur_map)
                        ent_ref = _entropy_from_map(ref_map)
                        cat_metrics["entropy_cur"] = ent_cur
                        cat_metrics["entropy_ref"] = ent_ref
                        cat_metrics["entropy_delta"] = ent_cur - ent_ref
                        k = min(top_k, max_card)
                    else:
                        # High cardinality: only top-k diagnostics.
                        k = top_k
                    mass_delta, removed, added, churn = _topk_churn_and_mass_change(cur_map, ref_map, k=k)
                    cat_metrics["topk_mass_delta"] = mass_delta
                    cat_metrics["topk_removed"] = removed
                    cat_metrics["topk_added"] = added
                    cat_metrics["topk_churn"] = churn
                except Exception:
                    self.logger.exception("Error assembling categorical metrics.")

            js_val = cat_metrics.get("js")
            drift_flag_cat, triggered_metric_cat = False, None
            if js_val is not None and js_val >= js_thresh:
                drift_flag_cat, triggered_metric_cat = True, "js"
            else:
                topk_mass = cat_metrics.get("topk_mass_delta")
                topk_churn = cat_metrics.get("topk_churn")
                if topk_mass is not None and abs(topk_mass) >= topk_mass_thresh:
                    drift_flag_cat, triggered_metric_cat = True, "topk_mass"
                elif topk_churn is not None and topk_churn >= topk_churn_thresh:
                    drift_flag_cat, triggered_metric_cat = True, "topk_churn"
            if drift_flag_cat:
                self.logger.info(f"Categorical drift detected for {sys}.{tbl}.{resolved_col} (triggered by {triggered_metric_cat}).")

            try:
                _emit("categorical_drift", resolved_col, str(drift_flag_cat),
                      metric_value_num=float(js_val) if js_val is not None else None,
                      threshold=str(js_thresh),
                      status="ALERT" if drift_flag_cat else "PASS",
                      alert_flag="Fail" if drift_flag_cat else "Pass")
            except Exception:
                self.logger.exception("Failed to emit categorical_drift metric.")

            if chi2_val is not None:
                try:
                    _emit("csi", resolved_col, f"{chi2_val:.6f}",
                          metric_value_num=float(chi2_val),
                          threshold=f"<={csi_max:.6f}",
                          status="PASS" if chi2_val <= csi_max else "FAIL",
                          alert_flag="Fail" if chi2_val > csi_max else None)
                except Exception:
                    self.logger.exception("Failed to emit CSI (chi2) metric.")

        # ------------------------------------------------------------------
        # Numeric drift
        # ------------------------------------------------------------------
        for col_name in numeric_col_names:
            _check_numeric_column(col_name)

        # ------------------------------------------------------------------
        # Categorical drift
        # ------------------------------------------------------------------
        date_cols_set = {str(c).lower() for c in (tbl_conf.get("date_columns", {}) or {})}
        numeric_names_lower = {n.lower() for n in numeric_col_names}
        candidates = list(categorical_col_names)
        # Every other configured column that is neither numeric-drift nor a date column.
        for entry in (tbl_conf.get("columns", []) or []):
            name = _entry_name(entry)
            if name and name.lower() not in numeric_names_lower and name.lower() not in date_cols_set:
                candidates.append(name)
        # Deterministic, case-insensitive de-duplication (config order preserved).
        seen_lower = set()
        cols_to_check_csi = []
        for name in candidates:
            if name.lower() not in seen_lower:
                seen_lower.add(name.lower())
                cols_to_check_csi.append(name)

        for col_name in cols_to_check_csi:
            _check_categorical_column(col_name)
