# Correlation Analysis: Distribution Drift (KS/JS) vs Model AUROC

This script computes correlations between **distribution drift** (in model scores and in label distributions) and **model performance (AUROC)** for three tasks: **mortality, length-of-stay (LOS), and 30-day readmission**.

---

## 1. Inputs

Located in the `results/` directory:

- **Scored predictions** (row-level):  
  - `mortality_scored_predictions.csv`  
  - `los_scored_predictions.csv`  
  - `readmit30_scored_predictions.csv`  
  Each file must include:  
  - `split_date` → which split this row belongs to  
  - `admitdatetime` → timestamp (used to form weekly periods)  
  - `y_true` → true binary label (0/1)  
  - `y_pred` → predicted probability (0–1)

- **Weekly drift CSVs** (already produced by the drift monitoring script):  
  - `mortality_distribution_drift_weekly.csv`  
  - `los_distribution_drift_weekly.csv`  
  - `readmit30_distribution_drift_weekly.csv`  
  Each file must include:  
  - `split_date`, `week` → start date of the weekly period (must use the same binning as `TIME_FREQ`)  
  - `drift_score_KS` → Kolmogorov–Smirnov statistic on scores  
  - `drift_label_JS` → Jensen–Shannon divergence on label distributions  
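  - (Optional: `drift_score_PSI` may also be present. It is not correlated by default; to include it, add it to the `metrics` tuple passed to `correlate_drift_with_auroc`.)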
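
A quick header check catches a missing column before any computation runs. The following is a minimal stdlib sketch, assuming plain comma-separated files with a header row. `missing_columns` is a hypothetical helper; the script performs the equivalent check inside `main()`.

```python
import csv
import io

# Required columns per input file, as listed above
REQUIRED_SCORED = {"split_date", "admitdatetime", "y_true", "y_pred"}
REQUIRED_DRIFT = {"split_date", "week", "drift_score_KS", "drift_label_JS"}


def missing_columns(csv_text: str, required: set) -> set:
    """Return the required columns absent from the CSV header row."""
    reader = csv.reader(io.StringIO(csv_text))
    header = next(reader, [])  # empty input -> no columns at all
    return required - {h.strip() for h in header}


# Made-up drift CSV that carries PSI but lacks the label JS column
drift_csv = "split_date,week,drift_score_KS,drift_score_PSI\n2021-01-01,2021-01-04,0.12,0.03\n"
print(missing_columns(drift_csv, REQUIRED_DRIFT))  # {'drift_label_JS'}
```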

---

## 2. Processing Steps

### a. Compute Weekly AUROC
For each task:
1. Add a `period` column (weekly bins of `admitdatetime`).
2. For each `split_date` × `period`:
   - Compute **AUROC** from `y_true` vs `y_pred`.
   - Skip periods with only one class (AUROC undefined).

**Output:** DataFrame with `split_date, week, n_week, auroc`.
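
Step a can be sketched without pandas or scikit-learn. This dependency-free illustration is not the script's code. It assumes weeks start on Monday, which is the `start_time` that pandas' default `"W"` alias (`"W-SUN"`) produces. AUROC is computed with the rank-based Mann-Whitney formula, which gives the same value as `roc_auc_score` for binary labels. Rows missing a timestamp, label, or score are dropped as whole rows, so labels and scores stay paired. All data values are made up.

```python
from collections import defaultdict
from datetime import date, datetime, timedelta


def week_start(ts: datetime) -> date:
    """Monday of the week containing ts (the start_time of a pandas 'W' period)."""
    d = ts.date()
    return d - timedelta(days=d.weekday())


def auroc(y, p):
    """Rank-based AUROC: Mann-Whitney U / (n_pos * n_neg); tied scores share average ranks.

    Returns None when only one class is present (AUROC undefined).
    """
    n_pos = sum(y)
    n_neg = len(y) - n_pos
    if n_pos == 0 or n_neg == 0:
        return None
    order = sorted(range(len(p)), key=lambda i: p[i])
    ranks = [0.0] * len(p)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and p[order[j + 1]] == p[order[i]]:
            j += 1
        for k in range(i, j + 1):
            ranks[order[k]] = (i + j) / 2 + 1  # average 1-based rank of the tie group
        i = j + 1
    rank_sum_pos = sum(r for r, label in zip(ranks, y) if label == 1)
    return (rank_sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


def weekly_auroc(rows):
    """rows: (split_date, admit_ts, y_true, y_pred) tuples -> list of per-week dicts."""
    groups = defaultdict(lambda: ([], []))
    for split, ts, y, p in rows:
        if ts is None or y is None or p is None:
            continue  # drop the whole row so labels and scores stay paired
        ys, ps = groups[(split, week_start(ts))]
        ys.append(int(y))
        ps.append(float(p))
    out = []
    for (split, week), (ys, ps) in sorted(groups.items()):
        auc = auroc(ys, ps)
        if auc is not None:  # skip single-class weeks
            out.append({"split_date": split, "week": week, "n_week": len(ys), "auroc": auc})
    return out


rows = [
    ("2021-01-01", datetime(2021, 3, 1, 8), 1, 0.90),   # Monday
    ("2021-01-01", datetime(2021, 3, 3, 12), 0, 0.20),  # Wednesday, same week
    ("2021-01-01", datetime(2021, 3, 7, 23), 0, 0.95),  # Sunday, same week
    ("2021-01-01", datetime(2021, 3, 2, 9), 1, None),   # missing score: whole row dropped
    ("2021-01-01", datetime(2021, 3, 8, 9), 0, 0.40),   # next week, one class: skipped
]
print(weekly_auroc(rows))  # one row: week 2021-03-01, n_week=3, auroc=0.5
```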

---

### b. Align Drift with AUROC
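- Join drift (`drift_score_KS`, `drift_label_JS`) with AUROC on `split_date` and `week`. Only weeks present in both tables contribute to a correlation, and `split_date` is compared as a string on both sides.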
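
The join can be illustrated with plain dictionaries. This minimal sketch uses made-up rows and a hypothetical `align` helper. Casting `split_date` to `str` only helps if both sides stringify identically. For example, a parsed pandas `Timestamp` prints with a time component (`2021-01-01 00:00:00`), so it would not match a plain `2021-01-01` string.

```python
from datetime import date


def align(drift_rows, auroc_rows):
    """Inner-join drift and AUROC rows on (split_date, week).

    split_date is cast to str on both sides before matching (as the script does),
    so a datetime.date and its ISO string compare equal.
    """
    auroc_by_key = {(str(r["split_date"]), r["week"]): r["auroc"] for r in auroc_rows}
    merged = []
    for r in drift_rows:
        key = (str(r["split_date"]), r["week"])
        if key in auroc_by_key:  # weeks without AUROC (e.g. single-class weeks) drop out
            merged.append({**r, "split_date": key[0], "auroc": auroc_by_key[key]})
    return merged


# Made-up rows: the same split written two ways; week 2021-03-08 has no AUROC
drift_rows = [
    {"split_date": date(2021, 1, 1), "week": date(2021, 3, 1), "drift_score_KS": 0.10},
    {"split_date": "2021-01-01", "week": date(2021, 3, 8), "drift_score_KS": 0.30},
]
auroc_rows = [{"split_date": "2021-01-01", "week": date(2021, 3, 1), "auroc": 0.81}]
print(align(drift_rows, auroc_rows))
```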

---

### c. Correlation Computation
For each `metric ∈ {drift_score_KS, drift_label_JS}`:

1. **Static correlations (lag=0):**  
   - Pearson and Spearman correlations between drift and AUROC.

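2. **Lagged correlations (up to ±`MAX_LAG` weeks, default 6):**  
   - Positive lag → **drift leads AUROC**: drift in week *t* is paired with AUROC in week *t + lag*, so a negative correlation suggests drift may precede an AUROC drop.  
   - Negative lag → **AUROC leads drift**: drift in week *t* is paired with AUROC from an earlier week.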

**Metrics stored:**  
- `split_date`, `metric`, `lag_weeks`, `pearson`, `spearman`, `n_pairs`.
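
The lagged pairing in step c can be sketched in plain Python. This illustrative variant is not the script's code. It pairs weeks by calendar date (`week + lag` weeks). By contrast, `pandas.Series.shift` moves values by row position, so the two approaches agree only when no week is missing from the merged series. `pearson` and `lagged_pearson` are hypothetical helper names, and the toy values are made up.

```python
import math
from datetime import date, timedelta


def pearson(xs, ys):
    """Plain Pearson r; None for constant input (r undefined)."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    syy = sum((y - my) ** 2 for y in ys)
    if sxx == 0 or syy == 0:
        return None
    return sxy / math.sqrt(sxx * syy)


def lagged_pearson(drift, auroc, lag):
    """Correlate drift in week t with AUROC in week t + lag (calendar weeks).

    drift, auroc: {week_start_date: value} for one split_date.
    lag > 0: drift leads AUROC; lag < 0: AUROC leads drift.
    Returns (r, n_pairs); r is None if n_pairs < 3 (mirroring the script's cutoff).
    """
    step = timedelta(weeks=lag)
    pairs = [(v, auroc[w + step]) for w, v in drift.items() if w + step in auroc]
    if len(pairs) < 3:
        return None, len(pairs)
    xs, ys = zip(*pairs)
    return pearson(xs, ys), len(pairs)


# Toy series over six Mondays; AUROC is missing for W[2] (e.g. a single-class week)
W = [date(2021, 3, 1) + timedelta(weeks=k) for k in range(6)]
drift = {W[0]: 0.1, W[1]: 0.2, W[2]: 0.3, W[3]: 0.4, W[4]: 0.5}
auroc = {W[0]: 0.85, W[1]: 0.84, W[3]: 0.78, W[4]: 0.76, W[5]: 0.74}

for lag in (-1, 0, 1):
    print(lag, lagged_pearson(drift, auroc, lag))
```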

---

## 3. Outputs

- Per-task correlation tables:
  - `mortality_ks_js_auroc_correlations.csv`  
  - `los_ks_js_auroc_correlations.csv`  
  - `readmit30_ks_js_auroc_correlations.csv`

- Combined table across all tasks:
  - `correlations_all_tasks.csv`

Columns:
- `task` → which task (mortality/los/readmit30)  
- `split_date` → training split identifier  
- `metric` → `"SCORE_KS"` or `"LABEL_JS"`  
- `lag_weeks` → integer, positive means drift leads AUROC  
- `pearson`, `spearman` → correlation coefficients  
- `n_pairs` → number of weeks with both a drift value and a (lagged) AUROC value; correlations are left empty when this is below 3
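
Because the table uses a long format, you can rebuild one correlation-vs-lag curve per task, split, and metric. This stdlib sketch assumes the combined CSV has exactly the columns listed above. `lag_curves` is a hypothetical helper, and the sample rows are made up.

```python
import csv
import io
from collections import defaultdict


def lag_curves(csv_text):
    """Group rows of the combined correlation CSV into per-(task, split_date, metric) curves.

    Returns {(task, split_date, metric): [(lag_weeks, pearson_or_None), ...]} sorted by lag.
    Empty cells (how pandas writes NaN/None by default) become None.
    """
    curves = defaultdict(list)
    for row in csv.DictReader(io.StringIO(csv_text)):
        key = (row["task"], row["split_date"], row["metric"])
        r = float(row["pearson"]) if row["pearson"] else None
        curves[key].append((int(row["lag_weeks"]), r))
    return {k: sorted(v, key=lambda t: t[0]) for k, v in curves.items()}


# Made-up excerpt in the column layout described above
sample = """task,split_date,metric,lag_weeks,pearson,spearman,n_pairs
mortality,2021-01-01,SCORE_KS,1,-0.41,-0.38,20
mortality,2021-01-01,SCORE_KS,-1,0.05,0.02,20
mortality,2021-01-01,SCORE_KS,0,-0.22,-0.19,21
mortality,2021-01-01,LABEL_JS,6,,,2
"""
for key, points in lag_curves(sample).items():
    print(key, points)
```

Each list is ordered by lag, so it can serve directly as the x/y series of a plot, with `None` points skipped.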

---

## 4. Interpretation

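- **Strong negative correlation (Pearson well below 0) at lag=0:**  
  Weeks with higher drift tend to have lower AUROC **in the same week**.

- **Strong negative correlation at positive lag:**  
  Drift increases tend to **precede** AUROC degradation → candidate early-warning signal. The evidence is correlational, and coefficients based on few weeks (small `n_pairs`) are unstable.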

- **Spearman vs Pearson:**  
  - Pearson measures linear association.  
  - Spearman (rank-based) captures any monotonic relationship, including nonlinear ones, and is less sensitive to outliers.
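
The two coefficients can disagree when AUROC falls with drift monotonically but only drops sharply once drift is high. This dependency-free sketch uses made-up values. Tied values get average ranks, which matches the default behavior of `scipy.stats.spearmanr`.

```python
import math


def pearson(xs, ys):
    """Plain Pearson r; None for constant input."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    syy = sum((y - my) ** 2 for y in ys)
    if sxx == 0 or syy == 0:
        return None
    return sxy / math.sqrt(sxx * syy)


def average_ranks(values):
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for k in range(i, j + 1):
            ranks[order[k]] = (i + j) / 2 + 1
        i = j + 1
    return ranks


def spearman(xs, ys):
    """Spearman rho = Pearson r computed on the average ranks."""
    return pearson(average_ranks(xs), average_ranks(ys))


# Made-up weekly values: AUROC falls monotonically with drift, sharply only at high drift
drift = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30]
auroc = [0.850, 0.849, 0.848, 0.846, 0.800, 0.700]

print(round(pearson(drift, auroc), 2), spearman(drift, auroc))  # -0.8 -1.0
```

Here Spearman reports a perfect monotonic decline, while Pearson is pulled toward zero because the relationship is far from linear.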

---

## 5. Example Workflow

1. Run drift monitoring script → generates `{task}_distribution_drift_weekly.csv`.
2. Run this correlation script → generates per-task + combined CSVs.
3. Inspect:
   - Static correlations → snapshot of how drift and performance co-move within the same week.  
   - Lagged correlations → whether drift tends to **precede future AUROC shifts**.

---

## 6. Key Parameters

| Parameter | Default | Meaning |
|-----------|---------|---------|
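| `RESULTS_DIR` | `"results"` | Directory for all input and output CSVs. |
| `TASKS` | `["mortality","los","readmit30"]` | Tasks to process; must match the file prefixes. |
| `TIME_FREQ` | `"W"` | Pandas period frequency used to bin `admitdatetime` (weekly; periods start on Monday). Must match how `week` was built in the drift CSVs. |
| `MAX_LAG` | `6` | Maximum lag, in weeks, for lagged correlations (lags run from `-MAX_LAG` to `+MAX_LAG`). |
| `SAVE_COMBINED` | `True` | Whether to write `correlations_all_tasks.csv`. |
| `DRIFT_KS_COL`, `DRIFT_JS_COL` | `"drift_score_KS"`, `"drift_label_JS"` | Drift CSV columns correlated with AUROC. |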

---

## 7. Outputs for Visualization

- CSV tables can be used to:
  - Plot **correlation vs lag_weeks** curves.
  - Compare across tasks (mortality vs LOS vs readmit30).
  - Highlight **negative peaks at positive lags** (potential early warnings).
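
To surface candidate early warnings from the rows of one task/split/metric, a simple filter can help. This is a minimal sketch. `early_warning_lags`, its `threshold`, and its `min_pairs` are hypothetical illustrative choices, not values from the script, and the rows are made up.

```python
def early_warning_lags(rows, threshold=-0.5, min_pairs=8):
    """Positive lags (drift leads AUROC) whose Pearson r is at or below `threshold`.

    rows: dicts with lag_weeks, pearson (None if undefined), n_pairs for ONE
    task/split/metric. threshold and min_pairs are hypothetical defaults.
    Returns [(lag_weeks, pearson), ...], most negative first.
    """
    hits = [
        (r["lag_weeks"], r["pearson"])
        for r in rows
        if r["lag_weeks"] > 0
        and r["pearson"] is not None
        and r["n_pairs"] >= min_pairs  # skip correlations estimated from too few weeks
        and r["pearson"] <= threshold
    ]
    return sorted(hits, key=lambda h: h[1])


# Made-up rows for a single task/split/metric
rows = [
    {"lag_weeks": -1, "pearson": -0.95, "n_pairs": 19},  # AUROC leads: ignored
    {"lag_weeks": 0, "pearson": -0.70, "n_pairs": 20},   # same week: ignored
    {"lag_weeks": 1, "pearson": None, "n_pairs": 2},     # undefined
    {"lag_weeks": 2, "pearson": -0.62, "n_pairs": 18},
    {"lag_weeks": 3, "pearson": -0.81, "n_pairs": 17},
    {"lag_weeks": 5, "pearson": -0.90, "n_pairs": 4},    # too few pairs
]
print(early_warning_lags(rows))  # [(3, -0.81), (2, -0.62)]
```

The `min_pairs` guard matters because large lags leave fewer overlapping weeks, and a strong coefficient from a handful of weeks is easy to get by chance.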





# ============================================================
# Correlation Analysis: Distribution Drift (KS/JS) vs AUROC
# Loops over tasks: mortality, los, readmit30
#
# Inputs expected under ./results:
#   - <task>_scored_predictions.csv           (columns: split_date, admitdatetime, y_true, y_pred)
#   - <task>_distribution_drift_weekly.csv    (columns: split_date, week, drift_label_JS, drift_score_KS, ...)
#
# Outputs to ./results:
#   - <task>_ks_js_auroc_correlations.csv
#   - correlations_all_tasks.csv  (combined)
# ============================================================

import os, io
import numpy as np
import pandas as pd
from pandas.errors import EmptyDataError, ParserError
from sklearn.metrics import roc_auc_score
from scipy.stats import pearsonr, spearmanr

# -------------------
# Config
# -------------------
RESULTS_DIR = "results"
TASKS = ["mortality", "los", "readmit30"]     # must match your file prefixes
TIME_FREQ = "W"                               # should match how "week" was created in drift CSVs
MAX_LAG = 6                                   # ± lag (in periods, e.g., weeks)
SAVE_COMBINED = True

# Columns in the drift CSV we’ll correlate with AUROC
DRIFT_KS_COL = "drift_score_KS"
DRIFT_JS_COL = "drift_label_JS"

# -------------------
# Safe I/O helpers
# -------------------
def safe_read_csv(path, parse_dates=None):
    """
    Robust CSV reader: handles missing/empty/zero-byte files gracefully.
    Returns (df, ok_bool). If ok_bool=False, df is empty.
    """
    if not os.path.exists(path):
        print(f"[WARN] File not found: {path}")
        return pd.DataFrame(), False
    try:
        if os.path.getsize(path) == 0:
            print(f"[WARN] Zero-byte file: {path}")
            return pd.DataFrame(), False
    except OSError as e:
        print(f"[WARN] Could not stat file {path}: {e}")
        return pd.DataFrame(), False

    try:
        df = pd.read_csv(path, parse_dates=parse_dates)
    except EmptyDataError:
        print(f"[WARN] EmptyDataError: {path}")
        return pd.DataFrame(), False
    except ParserError as e:
        print(f"[WARN] ParserError on {path}: {e}")
        return pd.DataFrame(), False
    except UnicodeDecodeError:
        # Fallback: decode as UTF-8, silently dropping undecodable bytes
        with open(path, "rb") as f:
            raw = f.read()
        try:
            df = pd.read_csv(io.StringIO(raw.decode("utf-8", errors="ignore")), parse_dates=parse_dates)
        except Exception as e:
            print(f"[WARN] Fallback read failed for {path}: {e}")
            return pd.DataFrame(), False
    return df, True

def atomic_write_csv(df, path):
    """Write CSV atomically to avoid zero-byte partial files."""
    tmp = path + ".tmp"
    df.to_csv(tmp, index=False)
    os.replace(tmp, path)

# -------------------
# AUROC computation
# -------------------
def _add_period(df, date_col="admitdatetime", freq="W"):
    out = df.copy()
    out[date_col] = pd.to_datetime(out[date_col], errors="coerce")
    out["period"] = out[date_col].dt.to_period(freq).dt.start_time
    return out

def compute_weekly_auroc(scored_df,
                         date_col="admitdatetime",
                         split_col="split_date",
                         y_col="y_true",
                         p_col="y_pred",
                         freq="W"):
    """
    Returns DataFrame with columns: split_date, week, n_week, auroc
    Skips periods that have only a single class (AUROC undefined).
    """
    if scored_df.empty:
        return pd.DataFrame(columns=["split_date","week","n_week","auroc"])

    df = _add_period(scored_df, date_col=date_col, freq=freq)
    req = {split_col, "period", y_col, p_col}
    miss = req - set(df.columns)
    if miss:
        raise ValueError(f"compute_weekly_auroc: missing columns {miss}")

    rows = []
    for split, gsplit in df.groupby(split_col):
        for per, g in gsplit.groupby("period"):
            # Drop rows missing either the label or the score, so y and p stay row-aligned
            g = g.dropna(subset=[y_col, p_col])
            if g.empty:
                continue
            y = g[y_col].astype(int).to_numpy()
            p = g[p_col].astype(float).to_numpy()

            if len(np.unique(y)) < 2:
                # AUROC undefined for single-class period
                continue
            try:
                auc = float(roc_auc_score(y, p))
            except Exception:
                continue

            rows.append({
                "split_date": str(split),
                "week": pd.to_datetime(per),
                "n_week": int(len(y)),
                "auroc": auc
            })
    out = pd.DataFrame(rows, columns=["split_date","week","n_week","auroc"])
    return out.sort_values(["split_date","week"]).reset_index(drop=True)

# -------------------
# Correlation logic
# -------------------
def _corr_pair(x, y):
    """Return (pearson, spearman, n) over pairs where both values are present; correlations are None if n < 3 or undefined (e.g. constant input)."""
    x = pd.Series(x).astype(float)
    y = pd.Series(y).astype(float)
    mask = x.notna() & y.notna()
    x = x[mask]; y = y[mask]
    n = int(len(x))
    if n < 3:
        return None, None, n
    try:
        pr = pearsonr(x, y)[0]
    except Exception:
        pr = None
    try:
        sr = spearmanr(x, y, nan_policy="omit").correlation
    except Exception:
        sr = None
    return (float(pr) if pr is not None and np.isfinite(pr) else None,
            float(sr) if sr is not None and np.isfinite(sr) else None,
            n)

def correlate_drift_with_auroc(drift_df, auroc_df, metrics=("drift_score_KS","drift_label_JS"), max_lag=6):
    """
    Per split_date: compute static (lag 0) and lagged (±max_lag) correlations between drift metrics and AUROC.
    Returns a long DataFrame with columns: split_date, metric, lag_weeks, pearson, spearman, n_pairs
    """
    if drift_df.empty or auroc_df.empty:
        return pd.DataFrame(columns=["split_date","metric","lag_weeks","pearson","spearman","n_pairs"])

    # Normalize types
    drift = drift_df.copy()
    drift["split_date"] = drift["split_date"].astype(str)
    auroc = auroc_df.copy()
    auroc["split_date"] = auroc["split_date"].astype(str)

    rows = []
    for split in sorted(set(drift["split_date"]).intersection(set(auroc["split_date"]))):
        d = drift[drift["split_date"] == split].copy()
        a = auroc[auroc["split_date"] == split].copy()

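        # Join on week (ensure datetime; drop unparseable and duplicate weeks)
        d["week"] = pd.to_datetime(d["week"], errors="coerce")
        a["week"] = pd.to_datetime(a["week"], errors="coerce")
        d = d.dropna(subset=["week"]).drop_duplicates(subset="week", keep="last")
        a = a.dropna(subset=["week"]).drop_duplicates(subset="week", keep="last")
        if not set(d["week"]) & set(a["week"]):
            continue  # no overlapping weeks for this split

        # Outer-join, then reindex onto a complete weekly grid so that .shift(k)
        # below moves by k calendar weeks rather than k rows (weeks can be missing,
        # e.g. single-class weeks have no AUROC). Assumes weekly period start dates.
        base = pd.merge(d[["week"] + list(metrics)], a[["week", "auroc"]], on="week", how="outer")
        base = base.set_index("week").sort_index()
        base = base.reindex(pd.date_range(base.index.min(), base.index.max(), freq="7D"))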

        # For each metric and lag
        for metric in metrics:
            # Static (lag 0)
            pr, sr, n = _corr_pair(base[metric], base["auroc"])
            rows.append({
                "split_date": split,
                "metric": metric.replace("drift_","").upper(),
                "lag_weeks": 0,
                "pearson": pr, "spearman": sr, "n_pairs": n
            })

            # Lagged: positive lag => drift leads AUROC
            for lag in range(1, max_lag + 1):
                # Drift leads AUROC: pair drift at step t with AUROC at step t+lag
                lead = base.copy()
                lead["auroc_shifted"] = lead["auroc"].shift(-lag)
                pr, sr, n = _corr_pair(lead[metric], lead["auroc_shifted"])
                rows.append({
                    "split_date": split, "metric": metric.replace("drift_","").upper(),
                    "lag_weeks": +lag, "pearson": pr, "spearman": sr, "n_pairs": n
                })

                # AUROC leads drift: pair drift at step t with AUROC at step t-lag
                lagg = base.copy()
                lagg["auroc_shifted"] = lagg["auroc"].shift(+lag)
                pr, sr, n = _corr_pair(lagg[metric], lagg["auroc_shifted"])
                rows.append({
                    "split_date": split, "metric": metric.replace("drift_","").upper(),
                    "lag_weeks": -lag, "pearson": pr, "spearman": sr, "n_pairs": n
                })

    out = pd.DataFrame(rows, columns=["split_date","metric","lag_weeks","pearson","spearman","n_pairs"])
    return out.sort_values(["metric","split_date","lag_weeks"]).reset_index(drop=True)

# -------------------
# Driver
# -------------------
def main():
    os.makedirs(RESULTS_DIR, exist_ok=True)
    combined = []

    for task in TASKS:
        print(f"\n==== {task.upper()} ====")

        # 1) Load drift CSV (produced by your moving-split drift script)
        drift_path = os.path.join(RESULTS_DIR, f"{task}_distribution_drift_weekly.csv")
        drift_df, ok = safe_read_csv(drift_path, parse_dates=["week"])
        if (not ok) or drift_df.empty:
            print(f"[WARN] Missing/empty drift CSV for '{task}': {drift_path}")
            continue

        # Guard columns
        need = {"split_date","week", DRIFT_KS_COL, DRIFT_JS_COL}
        missing = need - set(drift_df.columns)
        if missing:
            print(f"[WARN] Drift CSV for '{task}' missing columns: {missing}. Skipping task.")
            continue

        # 2) Load scored predictions and compute weekly AUROC
        scored_path = os.path.join(RESULTS_DIR, f"{task}_scored_predictions.csv")
        scored_df, ok = safe_read_csv(scored_path, parse_dates=["admitdatetime"])
        if (not ok) or scored_df.empty:
            print(f"[WARN] Missing/empty scored predictions for '{task}': {scored_path}")
            continue

        # Guard columns
        need_scored = {"split_date","admitdatetime","y_true","y_pred"}
        missing_scored = need_scored - set(scored_df.columns)
        if missing_scored:
            print(f"[WARN] Scored predictions for '{task}' missing columns: {missing_scored}. Skipping task.")
            continue

        auroc_df = compute_weekly_auroc(
            scored_df,
            date_col="admitdatetime",
            split_col="split_date",
            y_col="y_true",
            p_col="y_pred",
            freq=TIME_FREQ
        )
        if auroc_df.empty:
            print(f"[WARN] No weekly AUROC rows for '{task}' (single-class weeks or no overlap).")
            continue

        # 3) Correlate (per split_date), static + lagged
        corr_df = correlate_drift_with_auroc(
            drift_df=drift_df,
            auroc_df=auroc_df,
            metrics=(DRIFT_KS_COL, DRIFT_JS_COL),
            max_lag=MAX_LAG
        )
        if corr_df.empty:
            print(f"[WARN] No overlapping weeks between drift and AUROC for '{task}'.")
            continue

        # 4) Save per-task CSV
        out_task = os.path.join(RESULTS_DIR, f"{task}_ks_js_auroc_correlations.csv")
        atomic_write_csv(corr_df, out_task)
        print(f"[OK] Saved correlations → {out_task} (rows={len(corr_df)})")

        # Quick console preview: static (lag=0)
        static_preview = corr_df[corr_df["lag_weeks"] == 0].copy()
        if not static_preview.empty:
            print("\n-- Static correlations (lag=0) --")
            print(static_preview[["split_date","metric","pearson","spearman","n_pairs"]].to_string(index=False))

        corr_df = corr_df.copy()
        corr_df.insert(0, "task", task)
        combined.append(corr_df)

    # 5) Save combined CSV across tasks
    if SAVE_COMBINED and combined:
        all_df = pd.concat(combined, ignore_index=True)
        out_all = os.path.join(RESULTS_DIR, "correlations_all_tasks.csv")
        atomic_write_csv(all_df, out_all)
        print(f"\n[OK] Saved combined correlations → {out_all} (rows={len(all_df)})")
        # Optional: quick best-negative-corr preview at positive lags (drift leads)
        try:
            lead = all_df[all_df["lag_weeks"] > 0].copy()
            if not lead.empty:
                best = (lead.sort_values(["metric","pearson"], ascending=[True, True])
                             .groupby(["task","metric"]).head(3))
                print("\n-- Strongest (more negative) Pearson at positive lags (top 3 per task/metric) --")
                print(best[["task","metric","split_date","lag_weeks","pearson","spearman","n_pairs"]].to_string(index=False))
        except Exception as e:
            print(f"[WARN] Could not build lead-lag preview: {e}")

    print("\n[DONE] Correlation analysis complete.")

if __name__ == "__main__":
    main()



