# Drift Diagnostics for MDE Category Predictions

This notebook inspects multi-delay embedding (MDE) category predictions alongside the underlying Schaefer ROI time series. It helps you quantify low-frequency drift, relate it to dominant brain components, and try simple mitigation strategies such as projecting out drift-prone principal components (PCs).

## How to use this notebook

1. Adjust the configuration cell (paths, subject/story, smoothing config, and categories).
2. Run the data loading cell to gather predictions and compute principal components of the brain activity.
3. Review the automatically generated summaries and plots to identify drift sources.
4. Optionally project out the leading PCs (set `DRIFT_PC_COUNT`) and, if desired, set `SAVE_CORRECTED = True` to write corrected prediction CSVs back to disk.

In [None]:
from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


# Notebook-wide display defaults: matplotlib theme, figure resolution,
# hidden top/right axes spines, and compact pandas table rendering.
plt.style.use("seaborn-v0_8")
plt.rcParams["figure.dpi"] = 120
plt.rcParams["axes.spines.top"] = False
plt.rcParams["axes.spines.right"] = False
pd.set_option("display.max_rows", 200)
# Render floats with four decimals in every displayed table.
pd.set_option("display.float_format", "{:0.4f}".format)

In [None]:
# ==== configuration ====
# Every path below is derived from PROJECT_ROOT, which is the resolved parent
# of the current working directory.
PROJECT_ROOT = Path("..").resolve()

# Subject and story identifiers; both are used as path components below and
# are passed to load_prediction when locating each category's CSV.
SUBJECT = "UTS01"
STORY = "wheretheressmoke"

# Day26 smoothing directory that holds per-category folders
# (folders whose names start with "cat_" are picked up by discover_categories).
PREDICTION_ROOT = PROJECT_ROOT / "figs" / SUBJECT / STORY / "day26_smoothing_cli_MDE50step"

# Brain activity cache (Schaefer ROI time series)
# BRAIN_TS_PATH must hold a 2D array; CACHE_META_PATH is a JSON file from which
# the "tr" entry is read (2.0 is used when that key is missing).
BRAIN_TS_PATH = PROJECT_ROOT / "data_cache" / SUBJECT / STORY / "schaefer_400.npy"
CACHE_META_PATH = PROJECT_ROOT / "data_cache" / SUBJECT / STORY / "cache_meta.json"

# Category subset and smoothing configuration
# SMOOTHING_DIR names the sub-folder inside each category folder to read from.
# CATEGORY_NAMES = None (or empty) means: use every "cat_*" folder found.
SMOOTHING_DIR = "gauss_1p00"  # e.g. "gauss_1p00", "movavg_1p00", or "none_0p00"
CATEGORY_NAMES: Optional[Sequence[str]] = None  # e.g. ["cat_temporal_", "cat_locational_"]

# PCA / plotting parameters
# MAX_PCS is handed to compute_pca (a value <= 0 keeps every component);
# PLOT_PC_LIMIT caps how many PC time courses get their own panel.
MAX_PCS = 6
PLOT_PC_LIMIT = 6

# Mitigation settings
# With SAVE_CORRECTED = True the correction cell writes, next to each source
# CSV, a file named "<source stem>_<CORRECTED_SUFFIX>.csv".
DRIFT_PC_COUNT = 1  # number of leading PCs to project out when correcting
PC_CORRECTION_TARGETS: Optional[Sequence[str]] = None  # None -> apply correction to all loaded categories
SAVE_CORRECTED = False
CORRECTED_SUFFIX = "driftcorrected"


In [None]:
# ==== helper functions ====
def discover_categories(prediction_root: Path, explicit: Optional[Sequence[str]] = None) -> List[str]:
    """Return the category names to analyse.

    A non-empty ``explicit`` sequence takes precedence and is returned as a
    list of strings without touching the filesystem. Otherwise the names of
    all sub-directories of ``prediction_root`` that start with ``"cat_"`` are
    returned in sorted order.

    Raises:
        FileNotFoundError: if no explicit names are given and
            ``prediction_root`` does not exist.
    """
    if explicit:
        return list(map(str, explicit))
    if not prediction_root.exists():
        raise FileNotFoundError(f"Prediction root not found: {prediction_root}")

    found: List[str] = []
    for child in prediction_root.iterdir():
        if child.is_dir() and child.name.startswith("cat_"):
            found.append(child.name)
    found.sort()
    return found


def load_prediction(
    prediction_root: Path,
    category: str,
    smoothing_dir: str,
    subject: str,
    story: str,
) -> Tuple[pd.DataFrame, Path]:
    """Load the best-prediction CSV of one category.

    The file is looked up at
    ``<prediction_root>/<category>/<smoothing_dir>/<subject>/<story>/day22_category_mde/mde_<name>_best_prediction.csv``
    where ``<name>`` is ``category`` without trailing underscores. If that
    exact path is missing, the smoothing folder is searched recursively, first
    for the same file name and then for any ``mde_*_best_prediction.csv``; the
    first match in sorted order is used.

    Args:
        prediction_root: Directory that contains one folder per category.
        category: Category folder name (e.g. ``"cat_temporal_"``).
        smoothing_dir: Sub-folder of the category folder; an empty string
            means the category folder itself is searched.
        subject: Subject identifier used in the expected path.
        story: Story identifier used in the expected path.

    Returns:
        ``(df, path)``: the loaded table and the CSV path it came from. The
        columns ``trim_index``, ``start_sec``, ``time``, ``target`` and
        ``prediction`` are coerced to numeric where present (unparseable
        entries become NaN). Rows are ordered by ``trim_index`` when that
        column exists; otherwise the file order is kept.

    Raises:
        FileNotFoundError: if the category folder, the smoothing folder, or a
            matching CSV cannot be found.
    """
    cat_dir = prediction_root / category
    if not cat_dir.exists():
        raise FileNotFoundError(f"Missing category folder: {cat_dir}")
    config_dir = cat_dir / smoothing_dir if smoothing_dir else cat_dir
    if not config_dir.exists():
        raise FileNotFoundError(f"Missing smoothing folder for {category}: {config_dir}")

    safe_name = category.rstrip("_")
    candidate = (
        config_dir
        / subject
        / story
        / "day22_category_mde"
        / f"mde_{safe_name}_best_prediction.csv"
    )
    if not candidate.exists():
        # Fall back to a recursive search: exact file name first, then any
        # best-prediction CSV below this category's smoothing folder.
        matches = list(config_dir.glob(f"**/mde_{safe_name}_best_prediction.csv"))
        if not matches:
            matches = list(config_dir.glob("**/mde_*_best_prediction.csv"))
        if not matches:
            raise FileNotFoundError(
                f"No best-prediction CSV found for {category} under {config_dir}"
            )
        candidate = sorted(matches)[0]

    df = pd.read_csv(candidate)
    for col in ("trim_index", "start_sec", "time", "target", "prediction"):
        if col in df:
            df[col] = pd.to_numeric(df[col], errors="coerce")
    # The coercion loop tolerates missing columns, so the ordering step must
    # too: sorting on an absent "trim_index" would raise KeyError.
    if "trim_index" in df:
        # A stable sort keeps the file order among rows with equal or NaN keys.
        df.sort_values("trim_index", inplace=True, ignore_index=True, kind="stable")
    return df, candidate


def load_cache_meta(path: Path) -> Dict[str, Any]:
    """Decode the UTF-8 JSON cache metadata stored at ``path``.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if path.exists():
        return json.loads(path.read_text(encoding="utf-8"))
    raise FileNotFoundError(f"Cache metadata not found: {path}")


def load_brain_timeseries(path: Path) -> np.ndarray:
    """Load a 2D brain time-series array and make it fully finite.

    Non-finite samples (NaN and +/-inf) are replaced by the median of the
    finite samples in the same column. A column without any finite sample is
    filled with 0.0 so that the result never contains NaN or inf.

    Args:
        path: Location of the ``.npy`` file.

    Returns:
        A float array with the same shape as the stored array.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the stored array is not two-dimensional.
    """
    if not path.exists():
        raise FileNotFoundError(f"Brain timeseries not found: {path}")
    arr = np.load(path)
    if arr.ndim != 2:
        raise ValueError(f"Expected 2D array for brain timeseries, found shape {arr.shape}")
    arr = arr.astype(float)
    finite = np.isfinite(arr)
    if not finite.all():
        # np.nanmedian skips NaN but not +/-inf, so infinities are masked to
        # NaN first; otherwise they would leak into the replacement value.
        cleaned = np.where(finite, arr, np.nan)
        col_fill = np.zeros(arr.shape[1], dtype=float)
        usable = finite.any(axis=0)
        if usable.any():
            # Only columns with at least one finite sample have a median;
            # the others keep the 0.0 fill.
            col_fill[usable] = np.nanmedian(cleaned[:, usable], axis=0)
        arr = np.where(finite, arr, col_fill)
    return arr


def compute_pca(matrix: np.ndarray, n_components: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run PCA on ``matrix`` through an SVD of the column-centred data.

    Args:
        matrix: 2D array; rows are treated as observations.
        n_components: Number of components to keep. Values ``<= 0`` keep all
            components; larger requests are clipped to what the SVD provides.

    Returns:
        ``(scores, components, explained)`` where ``scores`` holds the
        projections of the rows onto the kept components, ``components`` the
        kept right singular vectors (one per row), and ``explained`` the
        fraction of total variance per kept component (all zeros for a
        single-row input or when the total variance is zero).

    Raises:
        ValueError: if ``matrix`` has no elements.
    """
    if matrix.size == 0:
        raise ValueError("Cannot compute PCA on an empty matrix")

    centered = np.asarray(matrix, dtype=float)
    centered = centered - np.nanmean(centered, axis=0, keepdims=True)
    left, singular, right_t = np.linalg.svd(centered, full_matrices=False)

    available = right_t.shape[0]
    keep = available if n_components <= 0 else min(n_components, available)
    scores = left[:, :keep] * singular[:keep]
    components = right_t[:keep]

    n_rows = centered.shape[0]
    explained = np.zeros(keep)
    if n_rows > 1:
        variances = (singular ** 2) / (n_rows - 1)
        total = variances.sum()
        if total > 0:
            explained = variances[:keep] / total
    return scores, components, explained


def build_pc_dataframe(scores: np.ndarray, start_sec: np.ndarray) -> pd.DataFrame:
    """Assemble PC scores and their time axis into one table.

    The result has one ``PC<k>`` column per column of ``scores`` (numbered
    from 1), followed by ``start_sec`` and an integer ``tr_index`` column that
    counts rows from 0.
    """
    columns = {}
    for position, column in enumerate(scores.T, start=1):
        columns[f"PC{position}"] = column
    columns["start_sec"] = start_sec
    columns["tr_index"] = np.arange(len(start_sec), dtype=int)
    return pd.DataFrame(columns)


def compute_linear_trend(x: Iterable[float], y: Iterable[float]) -> Dict[str, float]:
    """Fit a straight line ``y ~ slope * x + intercept`` and describe it.

    Only pairs where both values are finite take part in the fit. With fewer
    than three such pairs every entry of the result is NaN.

    Returns:
        A dict with the keys ``slope``, ``intercept``, ``delta`` (slope times
        the span of the valid ``x`` values), ``r2`` (one minus residual
        variance over the variance of ``y``; NaN when ``y`` has zero
        variance) and ``slope_per_min`` (slope multiplied by 60).
    """
    xs = np.asarray(list(x), dtype=float)
    ys = np.asarray(list(y), dtype=float)
    usable = np.isfinite(xs) & np.isfinite(ys)
    if usable.sum() < 3:
        return dict.fromkeys(("slope", "intercept", "delta", "r2", "slope_per_min"), np.nan)

    xs, ys = xs[usable], ys[usable]
    slope, intercept = np.polyfit(xs, ys, 1)
    residual_var = np.var(ys - (slope * xs + intercept))
    total_var = np.var(ys)
    if total_var > 0:
        r2 = 1 - (residual_var / total_var)
    else:
        r2 = np.nan
    return {
        "slope": slope,
        "intercept": intercept,
        "delta": slope * (xs.max() - xs.min()),
        "r2": r2,
        "slope_per_min": slope * 60.0,
    }


def safe_corr(a: Iterable[float], b: Iterable[float]) -> float:
    """Pearson correlation of ``a`` and ``b`` over pairs that are both finite.

    Returns NaN when fewer than three finite pairs are available.
    """
    left = np.asarray(list(a), dtype=float)
    right = np.asarray(list(b), dtype=float)
    keep = np.isfinite(left) & np.isfinite(right)
    if keep.sum() >= 3:
        return float(np.corrcoef(left[keep], right[keep])[0, 1])
    return float("nan")


def project_out_components(values: np.ndarray, basis: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Least-squares fit of ``values`` on the columns of ``basis`` (no intercept).

    Rows where ``values`` or any basis column is non-finite are left out of
    the fit. A 1D ``basis`` is treated as a single column.

    Returns:
        ``(correction, coeffs)``: the fitted contribution of the basis at each
        usable row (NaN elsewhere) and the fitted coefficients. When there are
        no usable rows, or fewer usable rows than basis columns, both outputs
        are filled with NaN.
    """
    target = np.asarray(values, dtype=float)
    design = np.asarray(basis, dtype=float)
    if design.ndim == 1:
        design = design.reshape(-1, 1)
    n_terms = design.shape[1]

    usable = np.isfinite(design).all(axis=1) & np.isfinite(target)
    correction = np.full(target.shape, np.nan, dtype=float)
    n_usable = usable.sum()
    if n_usable == 0 or n_usable < n_terms:
        return correction, np.full(n_terms, np.nan, dtype=float)

    coeffs = np.linalg.lstsq(design[usable], target[usable], rcond=None)[0]
    fitted = design @ coeffs
    correction[usable] = fitted[usable]
    return correction, coeffs


def build_category_summary(
    category_store: Dict[str, Dict[str, Any]],
    max_pc: int,
) -> pd.DataFrame:
    """Tabulate drift statistics for every category in ``category_store``.

    For each entry the ``"merged"`` table provides linear trends over
    ``start_sec`` of the prediction, target and residual, plus the
    correlation of the prediction with each ``PC1`` .. ``PC<max_pc>`` column
    that is present. If the entry also has a ``"merged_corrected"`` table, the
    trend of its ``prediction_corrected`` column is reported; otherwise those
    columns are NaN.

    Returns:
        A DataFrame indexed by category and ordered by descending absolute
        ``pred_slope_per_min`` (NaN slopes count as 0). An empty DataFrame is
        returned when ``category_store`` has no entries.
    """
    pc_names = [f"PC{number}" for number in range(1, max_pc + 1)]
    # (output column, key in the trend dict) for the corrected prediction.
    corrected_fields = (
        ("pred_corrected_slope", "slope"),
        ("pred_corrected_slope_per_min", "slope_per_min"),
        ("pred_corrected_delta", "delta"),
        ("pred_corrected_r2_trend", "r2"),
    )

    records: List[Dict[str, Any]] = []
    for name, entry in category_store.items():
        frame = entry["merged"]
        times = frame["start_sec"]
        pred_fit = compute_linear_trend(times, frame["prediction"])
        target_fit = compute_linear_trend(times, frame["target"])
        resid_fit = compute_linear_trend(times, frame["residual"])

        record: Dict[str, Any] = {
            "category": name,
            "n_samples": frame.shape[0],
            "start_sec_min": float(np.nanmin(times)),
            "start_sec_max": float(np.nanmax(times)),
            "pred_slope": pred_fit["slope"],
            "pred_slope_per_min": pred_fit["slope_per_min"],
            "pred_delta": pred_fit["delta"],
            "pred_r2_trend": pred_fit["r2"],
            "target_slope": target_fit["slope"],
            "target_slope_per_min": target_fit["slope_per_min"],
            "target_delta": target_fit["delta"],
            "residual_slope": resid_fit["slope"],
            "residual_slope_per_min": resid_fit["slope_per_min"],
            "residual_r2_trend": resid_fit["r2"],
        }

        for pc_name in pc_names:
            if pc_name in frame:
                record[f"corr_{pc_name}"] = safe_corr(frame["prediction"], frame[pc_name])

        if "merged_corrected" in entry:
            fixed = entry["merged_corrected"]
            fixed_fit = compute_linear_trend(fixed["start_sec"], fixed["prediction_corrected"])
            for out_key, fit_key in corrected_fields:
                record[out_key] = fixed_fit[fit_key]
        else:
            for out_key, _ in corrected_fields:
                record[out_key] = np.nan

        records.append(record)

    if not records:
        return pd.DataFrame()

    summary = pd.DataFrame(records).set_index("category")
    if "pred_slope_per_min" in summary:
        # Strongest drift first; NaN slopes are ranked as if they were zero.
        summary = summary.sort_values(
            by="pred_slope_per_min",
            key=lambda col: np.abs(col.fillna(0.0)),
            ascending=False,
        )
    return summary

In [None]:
# ==== load predictions and brain PCs ====
# The repetition time comes from the cache metadata (2.0 s when the "tr" key
# is absent) and defines the time axis: brain row i starts at i * TR seconds.
cache_meta = load_cache_meta(CACHE_META_PATH)
tr = float(cache_meta.get("tr", 2.0))
print(f"Loaded cache metadata from {CACHE_META_PATH} (TR = {tr}s)")

brain_ts = load_brain_timeseries(BRAIN_TS_PATH)
time_axis = np.arange(brain_ts.shape[0], dtype=float) * tr
print(f"Brain timeseries shape: {brain_ts.shape}")

pc_scores, pc_components, pc_explained = compute_pca(brain_ts, MAX_PCS)
pc_df = build_pc_dataframe(pc_scores, time_axis)
# Integer millisecond key so that the merge below does not depend on exact
# floating-point equality of the start times.
pc_df["time_key"] = (pc_df["start_sec"] * 1000).round().astype(int)
print(f"Computed {pc_scores.shape[1]} principal components (explained var shown below)")

categories = discover_categories(PREDICTION_ROOT, CATEGORY_NAMES)
print(f"Scanning categories (smoothing='{SMOOTHING_DIR}'):\n" + ", ".join(categories))

# category -> {"csv_path": source file, "raw": prediction table,
#              "merged": prediction table joined with the PC scores}
category_store: Dict[str, Dict[str, Any]] = {}
for cat in categories:
    try:
        pred_df, csv_path = load_prediction(PREDICTION_ROOT, cat, SMOOTHING_DIR, SUBJECT, STORY)
    except FileNotFoundError as exc:
        # A category without a usable CSV is reported and skipped.
        print(f"[warning] {exc}")
        continue
    pred_df["time_key"] = (pred_df["start_sec"] * 1000).round().astype(int)
    merged = pred_df.merge(pc_df, how="left", on="time_key", suffixes=("", "_pc"))
    # "start_sec_pc" is the PC-side copy of the time axis; it is NaN exactly
    # where a prediction row found no brain TR with the same time key.
    unmatched = int(merged["start_sec_pc"].isna().sum())
    if unmatched:
        print(
            f"[warning] {cat}: {unmatched} of {len(merged)} prediction rows have no "
            "matching brain TR; their PC columns are NaN"
        )
    merged["residual"] = merged["prediction"] - merged["target"]
    category_store[cat] = {
        "csv_path": csv_path,
        "raw": pred_df,
        "merged": merged,
    }
print(f"Loaded {len(category_store)} categories with aligned PCs")

In [None]:
# ==== principal component diagnostics ====
# One linear trend per PC time course; the same fit feeds both the summary
# table and the dashed trend line in the plot below.
pc_summary_rows = []
pc_trends = {}
for idx in range(pc_scores.shape[1]):
    col = f"PC{idx + 1}"
    trend = compute_linear_trend(pc_df["start_sec"], pc_df[col])
    pc_trends[col] = trend
    pc_summary_rows.append({**trend, "pc": col, "explained_var": pc_explained[idx]})

pc_summary_df = pd.DataFrame(pc_summary_rows).set_index("pc")
display(pc_summary_df)

limit = min(PLOT_PC_LIMIT, pc_scores.shape[1])
if limit <= 0:
    # plt.subplots cannot create a grid with zero rows.
    print("PLOT_PC_LIMIT is not positive; skipping PC time-course plots.")
else:
    fig, axes = plt.subplots(limit, 1, figsize=(12, max(2.2 * limit, 3.5)), sharex=True)
    if limit == 1:
        # A single panel comes back as a bare Axes object, not an array.
        axes = [axes]
    for ax, idx in zip(axes, range(limit)):
        col = f"PC{idx + 1}"
        series = pc_df[col]
        ax.plot(pc_df["start_sec"], series, label=col)
        trend = pc_trends[col]
        if np.isfinite(trend["slope"]):
            fitted = trend["intercept"] + trend["slope"] * pc_df["start_sec"]
            ax.plot(pc_df["start_sec"], fitted, linestyle="--", color="black", linewidth=1.0)
            ax.text(
                0.01,
                0.95,
                f"slope/min={trend['slope_per_min']:.4f}\nexplained={pc_explained[idx]:.3f}",
                transform=ax.transAxes,
                va="top",
                ha="left",
                fontsize=9,
                bbox=dict(boxstyle="round", facecolor="white", alpha=0.7, edgecolor="none"),
            )
        ax.set_ylabel(col)
        ax.grid(alpha=0.25)
    axes[-1].set_xlabel("Time (s)")
    fig.suptitle("Schaefer ROI principal component time courses", y=0.995)
    fig.tight_layout()
    plt.show()

In [None]:
# ==== category-level drift summary (pre/post correction) ====
# One row per loaded category; the corrected-prediction columns stay NaN
# until the correction cell has stored a "merged_corrected" table.
category_summary_df = build_category_summary(category_store, pc_scores.shape[1])
if not category_summary_df.empty:
    display(category_summary_df)
else:
    print("No categories loaded. Check configuration paths.")

### Inspect a single category

Update `category_to_plot` below to visualise the target, prediction, and (if present) drift-corrected series for one category.

In [None]:
category_to_plot = "cat_temporal_"

payload = category_store.get(category_to_plot)
if payload is not None:
    # Use the corrected table when the correction cell has produced one.
    merged = payload.get("merged_corrected", payload["merged"])
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(merged["start_sec"], merged["target"], label="Target", linewidth=1.3)
    ax.plot(merged["start_sec"], merged["prediction"], label="Prediction", linewidth=1.1)
    if "prediction_corrected" in merged:
        ax.plot(
            merged["start_sec"],
            merged["prediction_corrected"],
            label="Prediction (corrected)",
            linewidth=1.1,
            linestyle="--",
        )
    ax.set(
        title=f"{category_to_plot} – predictions vs target",
        xlabel="Start time (s)",
        ylabel="Value",
    )
    ax.legend()
    ax.grid(alpha=0.3)
    fig.tight_layout()
    plt.show()
else:
    print(f"Category '{category_to_plot}' not found. Available: {list(category_store)}")

In [None]:
# ==== drift mitigation: project out leading PCs ====
# For each target category the prediction is regressed (least squares, no
# intercept) on the first DRIFT_PC_COUNT PC score columns, and the fitted part
# is subtracted. Rows without finite PC scores are left uncorrected.
correction_reports: List[Dict[str, Any]] = []
pc_cols = [f"PC{i}" for i in range(1, DRIFT_PC_COUNT + 1)]

if DRIFT_PC_COUNT <= 0:
    print("DRIFT_PC_COUNT set to 0; predictions will be copied without adjustment.")
elif DRIFT_PC_COUNT > pc_scores.shape[1]:
    # Only PC columns present in the merged tables can be used, so a larger
    # request is reported here instead of being truncated silently.
    print(
        f"[warning] DRIFT_PC_COUNT={DRIFT_PC_COUNT} exceeds the {pc_scores.shape[1]} "
        "computed PCs; only the available PCs will be projected out."
    )

target_categories = (
    [str(cat) for cat in PC_CORRECTION_TARGETS]
    if PC_CORRECTION_TARGETS
    else list(category_store.keys())
)

if not target_categories:
    print("No categories available for correction. Ensure category_store is populated.")
else:
    for cat in sorted(target_categories):
        payload = category_store.get(cat)
        if payload is None:
            print(f"[skip] Category '{cat}' not loaded.")
            continue

        # Always start from the uncorrected table so that re-running the cell
        # does not stack corrections.
        merged = payload["merged"].copy()
        available_pc_cols = [col for col in pc_cols if col in merged]

        correction_applied = False
        coeffs_list: List[float] = []

        if available_pc_cols and DRIFT_PC_COUNT > 0:
            correction, coeffs = project_out_components(
                merged["prediction"].to_numpy(),
                merged[available_pc_cols].to_numpy(),
            )
            coeffs_list = coeffs.tolist() if getattr(coeffs, "size", 0) else []
            if np.isfinite(correction).any():
                correction_applied = True
            # NaN marks rows that took no part in the fit; they get no correction.
            correction = np.where(np.isfinite(correction), correction, 0.0)
        else:
            correction = np.zeros(merged.shape[0], dtype=float)
            coeffs_list = [0.0] * len(available_pc_cols)

        merged["pc_drift_component"] = correction
        merged["prediction_corrected"] = merged["prediction"] - correction

        # payload is the dict stored in category_store, so these updates are
        # visible there without re-assigning it.
        payload["merged_corrected"] = merged
        payload["pc_coeffs"] = coeffs_list
        payload["correction_applied"] = correction_applied

        trend_before = compute_linear_trend(merged["start_sec"], merged["prediction"])
        trend_after = compute_linear_trend(merged["start_sec"], merged["prediction_corrected"])

        correction_reports.append(
            {
                "category": cat,
                "pcs_used": ",".join(available_pc_cols) if available_pc_cols else "(none)",
                "coefficients": coeffs_list,
                "correction_applied": correction_applied,
                "slope_before_per_min": trend_before["slope_per_min"],
                "slope_after_per_min": trend_after["slope_per_min"],
                "delta_before": trend_before["delta"],
                "delta_after": trend_after["delta"],
            }
        )

        if SAVE_CORRECTED:
            # "time_key" is a merge helper added while loading; it is not part
            # of the source CSV and is therefore kept out of the saved file.
            out_df = payload["raw"].drop(columns=["time_key"], errors="ignore")
            out_df["prediction_corrected"] = merged["prediction_corrected"]
            out_path = payload["csv_path"].with_name(
                payload["csv_path"].stem + f"_{CORRECTED_SUFFIX}.csv"
            )
            out_df.to_csv(out_path, index=False)
            payload["corrected_path"] = out_path
            print(f"Saved corrected predictions for '{cat}' → {out_path}")

    if correction_reports:
        display(pd.DataFrame(correction_reports))
    else:
        print("No correction reports generated.")

category_summary_updated = build_category_summary(category_store, pc_scores.shape[1])
if not category_summary_updated.empty:
    display(category_summary_updated)


In [None]:
# ==== visualize all corrected categories ====
# One panel per category that has a corrected table, in alphabetical order.
corrected_entries = [
    (cat, category_store[cat])
    for cat in sorted(category_store)
    if "merged_corrected" in category_store[cat]
    and "prediction_corrected" in category_store[cat]["merged_corrected"]
]

if corrected_entries:
    n_panels = len(corrected_entries)
    # squeeze=False always yields a 2D grid of Axes, even for a single panel.
    fig, axes = plt.subplots(
        n_panels,
        1,
        figsize=(12, max(3.0 * n_panels, 3.5)),
        sharex=False,
        squeeze=False,
    )
    for ax, (cat, payload) in zip(axes[:, 0], corrected_entries):
        merged = payload["merged_corrected"]
        ax.plot(merged["start_sec"], merged["target"], label="Target", linewidth=1.2, color="#1f77b4")
        ax.plot(
            merged["start_sec"],
            merged["prediction"],
            label="Prediction",
            linewidth=1.0,
            color="#d62728",
            alpha=0.6,
        )
        ax.plot(
            merged["start_sec"],
            merged["prediction_corrected"],
            label="Prediction (corrected)",
            linewidth=1.1,
            linestyle="--",
            color="#2ca02c",
        )
        slope_before = compute_linear_trend(merged["start_sec"], merged["prediction"])
        slope_after = compute_linear_trend(merged["start_sec"], merged["prediction_corrected"])
        # "copied" means no finite correction was fitted for this category.
        if payload.get("correction_applied", False):
            status = "applied"
        else:
            status = "copied"
        ax.set(
            title=f"{cat} – slope/min before={slope_before['slope_per_min']:.4f}, after={slope_after['slope_per_min']:.4f} ({status})",
            xlabel="Start time (s)",
            ylabel="Value",
        )
        ax.grid(alpha=0.3)
        ax.legend(loc="upper left")
    fig.tight_layout()
    plt.show()
else:
    print("No corrected categories available. Run the correction cell first or adjust PC_CORRECTION_TARGETS.")
