# SPEAR Results Analysis

Single-cell Prediction of gene Expression from ATAC-seq Regression.


This notebook generates SPEAR results figures, diagnostics, and export-ready tables for model comparison.


## Prereqs

- Place run outputs under `output/results/spear_results/` (one subfolder per run with `models/` and metrics CSVs).

- Ensure logs exist under `output/logs/` with `spear_*` naming if you want resource plots.

- Keep `analysis/model_name_lookup.tsv` present (already tracked in repo).

- Install dependencies (see README) and select the `spear_env` kernel.

### How to run

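1. Open this notebook with the working directory set to the repo root (starting it from `analysis/` or `scripts/` also works; the setup cell walks up to the root).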

2. Adjust the run include globs if you want to filter; otherwise leave defaults.

3. Run all cells top-to-bottom after outputs are in place.


## 1. Environment Setup

Imports, plotting defaults, and global configuration.


In [None]:
from __future__ import annotations

from dataclasses import dataclass, replace
from datetime import datetime
from fnmatch import fnmatch
from pathlib import Path
import sys
# Locate the repository root: start from the current working directory and
# climb out of any directory named "analysis" or "scripts".
project_root = Path.cwd().resolve()
while project_root.name in {'analysis', 'scripts'}:
    project_root = project_root.parent
src_root = project_root / 'src'
# Put both the root and its src/ folder on the import path. Each candidate is
# inserted at position 0, so when both are new the root ends up ahead of src/.
for candidate in (src_root, project_root):
    if str(candidate) not in sys.path:
        sys.path.insert(0, str(candidate))

from typing import Iterable, Optional, Sequence

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from IPython.display import Markdown, display, Image

# Configure plotting defaults for consistent styling
sns.set_theme(style="whitegrid", context="paper")
# Re-apply the "paper" context, this time with a font scale.
sns.set_context("paper", font_scale=1.1)
# Separate DPI values for on-screen figures and for saved files.
plt.rcParams.update({"figure.dpi": 160, "savefig.dpi": 320})
# Widen the pandas console/notebook display limits.
pd.options.display.max_columns = 120
pd.options.display.width = 180

## 2. Analysis Configuration

Project paths, figure/report locations, and default analysis parameters.


In [None]:
# Single source of truth for the paths and tunable parameters of this notebook
@dataclass
class AnalysisConfig:
    """Paths and default parameters shared by the analysis cells.

    Instantiating the config has a side effect: the figure directory, the
    reports directory and the folder that holds the lookup table are created
    if they do not exist yet. ``results_root`` is *not* created.
    """

    project_root: Path
    results_root: Path
    lookup_path: Path
    fig_dir: Path
    reports_dir: Path
    run_include_globs: tuple[str, ...] = ("*",)
    run_exclude: tuple[str, ...] = ()
    primary_split: str = "test"
    val_split: str = "val"
    train_split: str = "train"
    top_gene_count: int = 15
    top_model_count: int = 3
    random_seed: int = 7

    def __post_init__(self) -> None:
        """Make sure every directory the notebook writes into exists."""
        for directory in (self.fig_dir, self.reports_dir, self.lookup_path.parent):
            directory.mkdir(parents=True, exist_ok=True)

# Resolve the repository root again so this cell does not depend on the setup
# cell's variable: climb out of any "analysis" or "scripts" directory.
project_root = Path.cwd().resolve()
while project_root.name in {"analysis", "scripts"}:
    project_root = project_root.parent

# Build the notebook configuration. Constructing it creates the figure and
# report directories (see AnalysisConfig.__post_init__).
config = AnalysisConfig(
    project_root=project_root,
    results_root=project_root / "output" / "results" / "spear_results",
    lookup_path=project_root / "analysis" / "model_name_lookup.tsv",
    fig_dir=project_root / "analysis" / "figs" / "spear_results_analysis",
    reports_dir=project_root / "analysis" / "reports",
    run_include_globs=("*",),
    run_exclude=tuple(),
    random_seed=7,
    top_gene_count=15,
    top_model_count=3,
 )

# Fail early when there is nothing to analyse.
if not config.results_root.exists():
    raise FileNotFoundError(f"Results directory missing: {config.results_root}")
if not config.lookup_path.exists():
    # Seed the lookup table if it is missing so later steps can append to it.
    pd.DataFrame({
        "model_id": [],
        "model_display_name": [],
        "model_short_name": [],
    }).to_csv(
        config.lookup_path, sep="\t", index=False
    )

# Seed both the legacy global NumPy RNG and a dedicated Generator instance.
np.random.seed(config.random_seed)
rng = np.random.default_rng(config.random_seed)
# Registries filled by later cells (see register_figure / maybe_store_table).
FIGURES: dict[str, plt.Figure] = {}
TABLES: dict[str, pd.DataFrame] = {}

# Dataset label appended to plot titles when available.
dataset_title_suffix: str | None = None
# Run-level metadata; later cells add more keys via analysis_metadata.update().
analysis_metadata: dict[str, object] = {
    "generated_at": datetime.now().isoformat(timespec="seconds"),
    "project_root": config.project_root,
    "results_root": config.results_root,
}

## 2a. Optional Run Selection Overrides

Use the toggles and lists below to focus the notebook on a subset of runs.


In [None]:
# Quick toggles (edit these only)
# - TARGET_DATASET: switch between embryonic/endothelial
# - USE_1000PLUS100: include 1000-gene runs with 100-gene fallback
# - PREFER_1000_FALLBACK_100: keep most recent 1000-gene per model, fallback to 100-gene only if missing
TARGET_DATASET: str = "embryonic"  # "embryonic" or "endothelial"
USE_1000PLUS100: bool = True  # True = 1000 genes with 100 genes fallback
PREFER_1000_FALLBACK_100: bool = True  # de-duplicate by model

# Populate one or both of the lists below to narrow the analysis scope.
RUN_DIRECTORY_SELECTION: list[str | Path] = [
    # Examples:
    # "spear_1000genes_k5_pg20_20251108_xgboost",
    # Path("output/results/spear_results/spear_1000genes_k5_pg20_20251108_random_forest"),
]
# NOTE(review): whatever is assigned here is overwritten by both branches of
# the `if USE_1000PLUS100` statement below, so manual entries have no effect.
RUN_GLOB_SELECTION: list[str] = []

# Derive the run-name globs plus a label/description for the selected subset
# from the toggles above.
if USE_1000PLUS100:
    RUN_GLOB_SELECTION = [
        f"*{TARGET_DATASET}*1000genes*",
        f"*{TARGET_DATASET}*100genes*",
    ]
    RUN_SUBSET_LABEL: Optional[str] = f"{TARGET_DATASET}_1000plus100"
    RUN_SUBSET_DESCRIPTION: Optional[str] = f"{TARGET_DATASET.title()} | 1000 genes (+ 100 genes backup)"
else:
    RUN_GLOB_SELECTION = [
        f"*{TARGET_DATASET}*1000genes*",
    ]
    RUN_SUBSET_LABEL: Optional[str] = f"{TARGET_DATASET}_1000genes"
    RUN_SUBSET_DESCRIPTION: Optional[str] = f"{TARGET_DATASET.title()} | 1000 genes"


# --- Do not edit below unless you want to change the filtering logic itself. ---



## 3. Run Discovery Utilities

Helpers for locating model outputs and attaching display metadata.


In [None]:
# Canonical representation of a single trained model output folder
@dataclass(frozen=True)
class RunRecord:
    """Immutable description of one model directory inside a run's ``models/`` folder.

    The optional paths are ``None`` when none of the expected file names exist
    in the model directory (see ``discover_model_runs``).
    """

    run_name: str  # name of the run directory, i.e. the parent of ``models/``
    model_id: str  # name of the model sub-directory
    run_path: Path  # the run directory itself
    model_path: Path  # ``<run>/models/<model_id>``
    metrics_path: Optional[Path]  # per-gene metrics CSV, if present
    predictions_path: Optional[Path]  # predictions CSV, if present
    training_history_path: Optional[Path]  # training-history CSV, if present
    model_display: Optional[str] = None  # display name, filled in by attach_lookup


# Display names for model ids that the generic title-casing in
# _guess_display_name would render poorly.
LOOKUP_SPECIAL_CASES = {
    "cnn": "Convolutional Neural Network",
    "rnn": "Recurrent Neural Network",
    "lstm": "Long Short-Term Memory",
    "mlp": "Multilayer Perceptron",
    "svr": "Support Vector Regressor",
    "ols": "Ordinary Least Squares",
    "xgboost": "XGBoost",
    "catboost": "CatBoost",
    "hist_gradient_boosting": "Histogram Gradient Boosting",
    "extra_trees": "Extra Trees",
    "random_forest": "Random Forest",
    "elastic_net": "Elastic Net",
}

# Preferred short labels keyed by display name; consulted before the generic
# acronym builder.
SHORT_NAME_FALLBACKS = {
    "Multilayer Perceptron": "MLP",
    "Graph Neural Network": "GNN",
    "Convolutional Neural Network": "CNN",
    # Both spellings are listed: LOOKUP_SPECIAL_CASES emits the form without
    # "Network", which would otherwise be abbreviated to "LSM" because the
    # hyphenated "Short-Term" contributes a single initial.
    "Long Short-Term Memory": "LSTM",
    "Long Short-Term Memory Network": "LSTM",
    "Recurrent Neural Network": "RNN",
    "Transformer Encoder": "Transformer",
    "Ordinary Least Squares": "OLS",
    "Extra Trees": "Extra Trees",
    "Random Forest": "Random Forest",
    "Ridge Regression": "Ridge",
}

# Caches rebuilt by _update_model_name_maps from the lookup table.
MODEL_ID_TO_DISPLAY: dict[str, str] = {}
MODEL_ID_TO_SHORT: dict[str, str] = {}
MODEL_DISPLAY_TO_SHORT: dict[str, str] = {}


def _default_short_name(display_name: str) -> str:
    """Derive a compact label from a display name when none is supplied.

    Parentheses are treated as spaces, then the first character of every
    word that starts with an alphanumeric character is collected into an
    upper-case acronym. The acronym is used only when it has 2 to 5
    characters; otherwise the display name is returned unchanged. Non-string
    or blank input yields an empty string.
    """
    if not (isinstance(display_name, str) and display_name.strip()):
        return ""
    words = display_name.replace("(", " ").replace(")", " ").split()
    if not words:
        # Only parentheses were present; nothing to abbreviate.
        return display_name
    initials = [word[0] for word in words if word[0].isalnum()]
    acronym = "".join(initials).upper()
    return acronym if 2 <= len(acronym) <= 5 else display_name


def compute_heatmap_limits(
    values: pd.DataFrame | np.ndarray,
    lower_percentile: float = 5.0,
    upper_percentile: float = 95.0,
    clip: tuple[float, float] = (0.0, 1.0),
    min_buffer: float = 0.01,
) -> tuple[float, float]:
    """Pick colour-scale bounds that emphasise the dense part of ``values``.

    The bounds are the requested percentiles of the finite entries, padded by
    5% of their spread (at least ``min_buffer``) and clamped to ``clip``.

    Args:
        values: Numeric data; non-finite entries are ignored.
        lower_percentile: Percentile used for the lower bound.
        upper_percentile: Percentile used for the upper bound.
        clip: Hard ``(min, max)`` limits for the returned bounds.
        min_buffer: Smallest padding applied around the percentiles.

    Returns:
        ``(vmin, vmax)``. ``clip`` itself is returned when no finite values
        exist.
    """
    finite = np.asarray(values, dtype=float)
    finite = finite[np.isfinite(finite)]
    if finite.size == 0:
        return clip

    low_q = np.percentile(finite, lower_percentile)
    high_q = np.percentile(finite, upper_percentile)
    padding = max(min_buffer, (high_q - low_q) * 0.05)
    vmin = max(clip[0], low_q - padding)
    vmax = min(clip[1], high_q + padding)

    if vmin > vmax:
        # Clamping inverted the range: use the clip window when it is itself
        # well-formed, otherwise fall back to the raw data range.
        if clip[0] <= clip[1]:
            vmin, vmax = clip
        else:
            vmin, vmax = float(finite.min()), float(finite.max())

    if not np.isclose(vmin, vmax):
        return vmin, vmax

    # Degenerate range: widen it so the colour scale has some extent.
    widen = max(min_buffer, abs(vmin) * 0.1 or min_buffer)
    vmin = max(clip[0], vmin - widen)
    vmax = min(clip[1], vmax + widen)
    if vmin > vmax:
        return float(finite.min()), float(finite.max())
    return vmin, vmax


def to_short_name(name: str | None) -> str:
    """Map a display name to the short label used in titles and filenames.

    The lookup-table cache takes precedence, then the static fallbacks; an
    unknown name is returned unchanged and a missing name becomes "".
    """
    if not name:
        return ""
    try:
        return MODEL_DISPLAY_TO_SHORT[name]
    except KeyError:
        return SHORT_NAME_FALLBACKS.get(name, name)


def _read_lookup_table(path: Path) -> pd.DataFrame:
    """Load the tab-separated model-name lookup table.

    Args:
        path: Location of the TSV file.

    Returns:
        The table with a ``model_short_name`` column guaranteed to exist;
        blank or missing short names are derived from the display name.
        An empty, correctly-columned frame is returned when the file is
        absent.

    Raises:
        ValueError: If ``model_id`` or ``model_display_name`` is missing.
    """
    table_path = Path(path)
    if not table_path.exists():
        return pd.DataFrame(
            {
                column: pd.Series(dtype="string")
                for column in ("model_id", "model_display_name", "model_short_name")
            }
        )

    table = pd.read_csv(table_path, sep="\t")
    absent = {"model_id", "model_display_name"} - set(table.columns)
    if absent:
        raise ValueError(f"Lookup table missing required columns: {sorted(absent)}")

    if "model_short_name" in table.columns:
        # Fill only the rows whose short name is empty or whitespace.
        table["model_short_name"] = table["model_short_name"].fillna("")
        blank = table["model_short_name"].str.strip() == ""
        table.loc[blank, "model_short_name"] = table.loc[blank, "model_display_name"].map(_default_short_name)
    else:
        table["model_short_name"] = table["model_display_name"].map(_default_short_name)
    return table


def _update_model_name_maps(df: pd.DataFrame) -> None:
    """Rebuild the module-level id/display/short name caches from ``df``.

    An empty table clears all three caches. Missing cells are treated as
    empty strings: a blank display name falls back to a guessed one and a
    blank short name falls back to the display name.
    """
    global MODEL_ID_TO_DISPLAY, MODEL_ID_TO_SHORT, MODEL_DISPLAY_TO_SHORT
    if df.empty:
        MODEL_ID_TO_DISPLAY, MODEL_ID_TO_SHORT, MODEL_DISPLAY_TO_SHORT = {}, {}, {}
        return

    rows = list(df.fillna("").itertuples(index=False))

    id_to_display: dict[str, str] = {}
    for row in rows:
        id_to_display[row.model_id] = row.model_display_name or _guess_display_name(row.model_id)
    MODEL_ID_TO_DISPLAY = id_to_display

    id_to_short: dict[str, str] = {}
    for row in rows:
        id_to_short[row.model_id] = row.model_short_name or id_to_display[row.model_id]
    MODEL_ID_TO_SHORT = id_to_short

    # Built last so it sees the final display/short value of every id.
    display_to_short: dict[str, str] = {}
    for row in rows:
        display_to_short[id_to_display[row.model_id]] = id_to_short[row.model_id]
    MODEL_DISPLAY_TO_SHORT = display_to_short


def _guess_display_name(model_id: str) -> str:
    """Turn a model id into a human-readable name.

    Known ids come from ``LOOKUP_SPECIAL_CASES``. Otherwise hyphens and
    underscores become spaces, words of up to three characters are
    upper-cased and longer words are capitalised. An id without any word
    characters is returned unchanged.
    """
    try:
        return LOOKUP_SPECIAL_CASES[model_id]
    except KeyError:
        pass
    words = [word for word in model_id.replace("-", " ").replace("_", " ").split(" ") if word]
    if not words:
        return model_id
    return " ".join(word.upper() if len(word) <= 3 else word.capitalize() for word in words)


def _matches_any(value: str, patterns: Iterable[str]) -> bool:
    """Return True when ``value`` matches at least one shell-style pattern.

    A falsy ``patterns`` argument (empty or ``None``) never matches.
    """
    if not patterns:
        return False
    for pattern in patterns:
        if fnmatch(value, pattern):
            return True
    return False


def _first_existing(path: Path, candidates: Sequence[str]) -> Optional[Path]:
    """Return the first ``path / name`` that exists, in candidate order, or None."""
    locations = (path / name for name in candidates)
    return next((location for location in locations if location.exists()), None)


def discover_model_runs(
    results_root: Path,
    include_globs: Iterable[str],
    exclude_patterns: Iterable[str],
) -> list[RunRecord]:
    """Scan ``results_root`` recursively and describe every model directory.

    Every directory named ``models`` (at any depth) is treated as belonging
    to the run named after its parent folder. Runs are kept when their name
    matches an include glob and no exclude pattern.

    Args:
        results_root: Directory to search.
        include_globs: Shell-style patterns for run names; falsy means "*".
        exclude_patterns: Shell-style patterns for run names to drop.

    Returns:
        One ``RunRecord`` per model sub-directory, in sorted path order.

    Raises:
        FileNotFoundError: If ``results_root`` does not exist.
    """
    # Accepted file names, in priority order, for each optional artefact.
    metrics_names = ("metrics_per_gene.csv", "metrics_by_gene.csv", "metrics_cv.csv")
    prediction_names = ("predictions_raw.csv", "predictions.csv")
    history_names = ("training_history.csv", "training_history_loss.csv")

    root = Path(results_root)
    if not root.exists():
        raise FileNotFoundError(f"Results root missing: {root}")

    wanted = tuple(include_globs) if include_globs else ("*",)
    unwanted = tuple(exclude_patterns) if exclude_patterns else tuple()

    found: list[RunRecord] = []
    for models_dir in sorted(root.rglob("models")):
        if not models_dir.is_dir():
            continue
        run_dir = models_dir.parent
        if not _matches_any(run_dir.name, wanted):
            continue
        if unwanted and _matches_any(run_dir.name, unwanted):
            continue
        found.extend(
            RunRecord(
                run_name=run_dir.name,
                model_id=model_dir.name,
                run_path=run_dir,
                model_path=model_dir,
                metrics_path=_first_existing(model_dir, metrics_names),
                predictions_path=_first_existing(model_dir, prediction_names),
                training_history_path=_first_existing(model_dir, history_names),
            )
            for model_dir in sorted(models_dir.iterdir())
            if model_dir.is_dir()
        )
    return found


def ensure_model_lookup(path: Path, model_ids: Iterable[str]) -> pd.DataFrame:
    """Make sure every model id has a row in the lookup table.

    Ids not yet present get a guessed display name and a short name. The
    file at ``path`` is rewritten (sorted by ``model_id``) only when rows
    were added.

    Returns:
        The lookup table, including any newly added rows.
    """
    lookup_path = Path(path)
    lookup = _read_lookup_table(lookup_path)
    known = set() if lookup.empty else set(lookup["model_id"])

    additions = []
    for model_id in sorted(set(model_ids) - known):
        display = _guess_display_name(model_id)
        additions.append(
            {
                "model_id": model_id,
                "model_display_name": display,
                "model_short_name": SHORT_NAME_FALLBACKS.get(display, _default_short_name(display)),
            }
        )
    if not additions:
        return lookup

    extra = pd.DataFrame(additions)
    lookup = extra if lookup.empty else pd.concat([lookup, extra], ignore_index=True)
    lookup.sort_values("model_id", inplace=True)
    lookup.to_csv(lookup_path, sep="\t", index=False)
    return lookup


def attach_lookup(records: Sequence[RunRecord], model_lookup: pd.DataFrame) -> list[RunRecord]:
    """Return copies of ``records`` with ``model_display`` filled in.

    Display names come from ``model_lookup``; ids missing from the table get
    a guessed name. With an empty lookup the records are returned unchanged
    (as a new list).
    """
    if model_lookup.empty:
        return list(records)
    names_by_id = dict(zip(model_lookup["model_id"], model_lookup["model_display_name"]))
    return [
        replace(
            record,
            model_display=names_by_id.get(record.model_id, _guess_display_name(record.model_id)),
        )
        for record in records
    ]


def run_records_to_frame(records: Sequence[RunRecord]) -> pd.DataFrame:
    """Tabulate run records, one row per record.

    A record without a display name gets a guessed one. An empty input
    yields an empty frame that still carries the expected columns.
    """
    if not records:
        empty_columns = [
            "run_name",
            "model_id",
            "model_display",
            "run_path",
            "model_path",
            "metrics_path",
            "predictions_path",
            "training_history_path",
        ]
        return pd.DataFrame(columns=empty_columns)

    rows = []
    for record in records:
        rows.append(
            {
                "run_name": record.run_name,
                "model_id": record.model_id,
                "model_display": record.model_display or _guess_display_name(record.model_id),
                "run_path": record.run_path,
                "model_path": record.model_path,
                "metrics_path": record.metrics_path,
                "predictions_path": record.predictions_path,
                "training_history_path": record.training_history_path,
            }
        )
    return pd.DataFrame(rows)


def to_relative_path(path_like: Optional[Path], root: Path) -> Optional[str]:
    """Express ``path_like`` relative to ``root`` for display purposes.

    Both paths are resolved before comparison, so a relative or symlinked
    ``root`` still matches the resolved target.

    Args:
        path_like: Path (or path string) to convert; ``None`` is passed through.
        root: Directory the result should be relative to.

    Returns:
        The relative path as a string, the resolved absolute path when the
        target is not located under ``root``, or ``None`` for ``None`` input.
    """
    if path_like is None:
        return None
    resolved = Path(path_like).resolve()
    try:
        return str(resolved.relative_to(Path(root).resolve()))
    except (ValueError, TypeError):
        # ValueError: not under root. TypeError: root is not path-like.
        return str(resolved)


def maybe_store_table(store: dict[str, pd.DataFrame], key: str, table: pd.DataFrame) -> None:
    """Record ``table`` in ``store`` under ``key`` unless it is None or empty."""
    if table is not None and not table.empty:
        store[key] = table


def _format_title_with_dataset(title: str | None, suffix: str | None) -> str | None:
    """Append ``" | <suffix>"`` to ``title`` unless it is already present.

    A falsy title or suffix leaves the title untouched.
    """
    needs_suffix = bool(title) and bool(suffix) and suffix not in title
    return f"{title} | {suffix}" if needs_suffix else title


def register_figure(store: dict[str, object], key: str, fig: Optional[plt.Figure]) -> None:
    """Track a generated matplotlib figure for later export.

    Passing ``None`` removes ``key`` from ``store``. When the module-level
    ``dataset_title_suffix`` is set, it is appended to every axes title and
    to the figure suptitle (read through the private ``_suptitle`` attribute)
    before the figure is stored.
    """
    if fig is None:
        store.pop(key, None)
        return

    def _suffixed(text):
        """Return the retitled text, or None when nothing should change."""
        candidate = _format_title_with_dataset(text, dataset_title_suffix)
        return candidate if candidate and candidate != text else None

    if dataset_title_suffix:
        for axis in fig.axes:
            new_title = _suffixed(axis.get_title())
            if new_title is not None:
                axis.set_title(new_title)
        suptitle = fig._suptitle
        if suptitle is not None:
            new_text = _suffixed(suptitle.get_text())
            if new_text is not None:
                suptitle.set_text(new_text)
    store[key] = fig


def load_metrics(records: Sequence[RunRecord]) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load per-gene metrics for every record that has a metrics file.

    Files that cannot be read are reported and skipped; files lacking the
    ``gene``, ``split`` or ``pearson`` columns are skipped silently.

    Returns:
        ``(metrics_long, metrics_wide)``. The long frame stacks all files
        with run/model identifiers in the leading columns. The wide frame
        has one row per run/model/gene and one ``<split>_pearson`` column per
        split. Both are empty frames when nothing could be loaded.
    """
    required = {"gene", "split", "pearson"}
    frames: list[pd.DataFrame] = []
    for record in records:
        source = record.metrics_path
        if source is None or not source.exists():
            continue
        try:
            table = pd.read_csv(source)
        except Exception as exc:
            print(f"Failed to load metrics from {source}: {exc}")
            continue
        if not required.issubset(table.columns):
            continue
        frames.append(
            table.assign(
                run_name=record.run_name,
                model_id=record.model_id,
                model_display=record.model_display or _guess_display_name(record.model_id),
            )
        )
    if not frames:
        return pd.DataFrame(), pd.DataFrame()

    combined = pd.concat(frames, ignore_index=True)
    # Identifier columns first, remaining metric columns in their file order.
    id_columns = [
        name
        for name in ("run_name", "model_id", "model_display", "gene", "split")
        if name in combined.columns
    ]
    value_columns = [name for name in combined.columns if name not in id_columns]
    metrics_long = combined[id_columns + value_columns]

    pivoted = metrics_long.pivot_table(
        index=["run_name", "model_id", "model_display", "gene"],
        columns="split",
        values="pearson",
    )
    pivoted.columns = [f"{str(name).lower()}_pearson" for name in pivoted.columns]
    return metrics_long, pivoted.reset_index()


def compute_model_summary(
    metrics_wide: pd.DataFrame,
    splits: Sequence[str],
) -> pd.DataFrame:
    """Aggregate per-gene Pearson scores into one row per run/model.

    Args:
        metrics_wide: Wide metrics table with ``<split>_pearson`` columns.
        splits: Split names to summarise; the column lookup is lower-cased.

    Returns:
        A frame indexed by ``(model_display, model_id, run_name)`` holding
        ``<split>_pearson_mean`` and ``<split>_pearson_std`` for every split
        with data, sorted by ``test_pearson_mean`` (descending) when that
        column exists. For empty input, an empty frame with the expected
        columns and no index is returned.
    """
    if metrics_wide.empty:
        mean_columns = [f"{split}_pearson_mean" for split in splits]
        std_columns = [f"{split}_pearson_std" for split in splits]
        return pd.DataFrame(
            columns=["model_display", "model_id", "run_name", *mean_columns, *std_columns]
        )

    group_keys = ["run_name", "model_id", "model_display"]
    lowered = [split.lower() for split in splits]
    rows = []
    for key_values, group in metrics_wide.groupby(group_keys, dropna=False):
        row = dict(zip(group_keys, key_values))
        for split, lower in zip(splits, lowered):
            column = f"{lower}_pearson"
            if column not in group:
                continue
            observed = group[column].dropna()
            if observed.empty:
                continue
            row[f"{split}_pearson_mean"] = observed.mean()
            # A sample standard deviation needs at least two observations.
            row[f"{split}_pearson_std"] = observed.std(ddof=1) if len(observed) > 1 else float("nan")
        rows.append(row)

    summary = pd.DataFrame(rows)
    if "test_pearson_mean" in summary:
        summary = summary.sort_values("test_pearson_mean", ascending=False)
    return summary.set_index(["model_display", "model_id", "run_name"])

## 4. Discover and Inspect Run Metadata

Enumerate available runs and assemble the run metadata table.


In [None]:
import re
# Find every model directory below the results root. The search is recursive,
# so any nested `models` directory is picked up as well.
raw_run_records = discover_model_runs(
    config.results_root,
    config.run_include_globs,
    config.run_exclude
)
# Make sure each discovered model id has a row in the lookup table; the TSV is
# rewritten when new ids are added.
model_lookup = ensure_model_lookup(
    config.lookup_path,
    [record.model_id for record in raw_run_records]
)
# Refresh the module-level name caches, then stamp display names on the records.
_update_model_name_maps(model_lookup)
run_records = attach_lookup(raw_run_records, model_lookup)


def _dataset_from_run_name(name: str) -> str | None:
    """Return the dataset keyword found in a run name (case-insensitive), if any.

    "embryonic" is checked before "endothelial".
    """
    lowered = str(name).lower()
    for dataset in ("embryonic", "endothelial"):
        if dataset in lowered:
            return dataset
    return None


def _gene_count_from_run_name(name: str) -> int | None:
    """Extract N from the first ``<N>genes`` token in a run name, or None."""
    found = re.search(r"(\d+)genes", str(name).lower())
    if found is None:
        return None
    digits = found.group(1)
    try:
        return int(digits)
    except ValueError:
        return None


def _run_timestamp_key(name: str) -> tuple[int, int]:
    """Build a sortable ``(date, time)`` key from a ``_YYYYMMDD_HHMMSS`` stamp.

    Only the digit counts (8 then 6) are checked. Names without such a stamp
    sort first with ``(0, 0)``.
    """
    found = re.search(r"_(\d{8})_(\d{6})", str(name))
    if found is None:
        return (0, 0)
    date_part, time_part = found.groups()
    return (int(date_part), int(time_part))


def _has_metrics(rec: RunRecord) -> bool:
    """Tell whether the record points at a metrics file that exists on disk."""
    candidate = rec.metrics_path
    if not candidate:
        return False
    if not isinstance(candidate, Path):
        candidate = Path(candidate)
    return candidate.exists()

def _model_dir_has_files(rec: RunRecord) -> bool:
    """Tell whether the record's model directory exists and has any entry.

    Any error raised while inspecting the path counts as "no files".
    """
    target = rec.model_path
    if not target:
        return False
    if not isinstance(target, Path):
        target = Path(target)
    try:
        if not target.exists():
            return False
        return any(target.iterdir())
    except Exception:
        return False


def _is_nonempty_record(rec: RunRecord) -> bool:
    """A record counts as non-empty if it has metrics or any file in its model dir."""
    if _has_metrics(rec):
        return True
    return _model_dir_has_files(rec)



def _select_latest_runs(records: list[RunRecord]) -> list[RunRecord]:
    """Keep the newest run with metrics per dataset / gene count / model.

    Records whose run name reveals no dataset or gene count are passed
    through untouched and appended after the selected ones. Groups in which
    no record has metrics contribute nothing.
    """
    if not records:
        return []

    buckets: dict[tuple[str, int, str], list[RunRecord]] = {}
    unparsed: list[RunRecord] = []
    for rec in records:
        dataset = _dataset_from_run_name(rec.run_name)
        gene_count = _gene_count_from_run_name(rec.run_name)
        if dataset and gene_count is not None:
            buckets.setdefault((dataset, gene_count, rec.model_id), []).append(rec)
        else:
            unparsed.append(rec)

    latest: list[RunRecord] = []
    for members in buckets.values():
        usable = [rec for rec in members if _has_metrics(rec)]
        if usable:
            # max() keeps the first of equally-stamped runs, like a stable
            # descending sort followed by taking the head.
            latest.append(max(usable, key=lambda rec: _run_timestamp_key(rec.run_name)))

    return latest + unparsed


# Collapse the discovered records to the newest run (with metrics) per
# dataset / gene count / model; unparseable run names pass through.
run_records = _select_latest_runs(list(run_records))

# Fallback: if a 1000-gene run is missing metrics, try archive 1000-gene runs first,
# then fall back to 100-gene runs when needed.
# NOTE(review): _select_latest_runs above already drops grouped records that
# lack metrics, so the replacement loop below looks reachable only for
# passed-through records (run names without a recognised dataset) — confirm.
if USE_1000PLUS100:
    archive_root = config.results_root / "archive"
    if archive_root.exists():
        # (dataset, model_id) pairs that have 1000-gene metrics somewhere:
        # first among the discovered records ...
        successful_1000: set[tuple[str, str]] = set()
        for record in raw_run_records:
            if "1000genes" not in record.run_name:
                continue
            metrics_path = record.metrics_path
            if metrics_path is None or not Path(metrics_path).exists():
                continue
            ds = _dataset_from_run_name(record.run_name) or ""
            if ds:
                successful_1000.add((ds, record.model_id))

        # ... then among the run directories directly under archive/.
        for run_dir in archive_root.iterdir():
            if not run_dir.is_dir() or "1000genes" not in run_dir.name:
                continue
            ds = _dataset_from_run_name(run_dir.name) or ""
            if not ds:
                continue
            models_dir = run_dir / "models"
            if not models_dir.exists():
                continue
            for model_dir in models_dir.iterdir():
                if not model_dir.is_dir():
                    continue
                metrics_path = _first_existing(model_dir, ("metrics_per_gene.csv", "metrics_by_gene.csv", "metrics_cv.csv"))
                if metrics_path and metrics_path.exists():
                    successful_1000.add((ds, model_dir.name))

        # Work on a copy so records can be swapped in place by index.
        updated_records: list[RunRecord | None] = list(run_records)
        for idx, rec in enumerate(list(run_records)):
            # Only 1000-gene records without a metrics file are candidates
            # for replacement.
            metrics_path = rec.metrics_path
            metrics_missing = metrics_path is None or not Path(metrics_path).exists()
            if not metrics_missing:
                continue
            if "1000genes" not in rec.run_name:
                continue

            dataset_label = _dataset_from_run_name(rec.run_name) or ""
            alt_model_dir = None
            alt_metrics = None
            alt_run_dir = None

            # First choice: the newest archived 1000-gene run (same dataset,
            # when known) that has metrics for this model.
            candidates = [d for d in archive_root.iterdir() if d.is_dir() and "1000genes" in d.name]
            if dataset_label:
                candidates = [d for d in candidates if dataset_label.lower() in d.name.lower()]
            for cand in sorted(candidates, key=lambda d: _run_timestamp_key(d.name), reverse=True):
                model_dir = cand / "models" / rec.model_id
                if not model_dir.exists():
                    continue
                alt_metrics = _first_existing(model_dir, ("metrics_per_gene.csv", "metrics_by_gene.csv", "metrics_cv.csv"))
                if alt_metrics and alt_metrics.exists():
                    alt_run_dir = cand
                    alt_model_dir = model_dir
                    break

            # Second choice: an archived 100-gene run, but only when no
            # 1000-gene metrics exist anywhere for this dataset/model pair.
            if alt_run_dir is None and (dataset_label, rec.model_id) not in successful_1000:
                candidates = [d for d in archive_root.iterdir() if d.is_dir() and "100genes" in d.name]
                if dataset_label:
                    candidates = [d for d in candidates if dataset_label.lower() in d.name.lower()]
                for cand in sorted(candidates, key=lambda d: _run_timestamp_key(d.name), reverse=True):
                    model_dir = cand / "models" / rec.model_id
                    if not model_dir.exists():
                        continue
                    alt_metrics = _first_existing(model_dir, ("metrics_per_gene.csv", "metrics_by_gene.csv", "metrics_cv.csv"))
                    if alt_metrics and alt_metrics.exists():
                        alt_run_dir = cand
                        alt_model_dir = model_dir
                        break

            # Swap in a record that points at the archived artefacts, keeping
            # the original model id and display name.
            if alt_metrics and alt_metrics.exists() and alt_model_dir and alt_model_dir.exists():
                alt_preds = _first_existing(
                    alt_model_dir,
                    (
                        "predictions_raw.csv",
                        "predictions.csv",
                    ),
                )
                alt_history = _first_existing(
                    alt_model_dir,
                    (
                        "training_history.csv",
                        "training_history_loss.csv",
                    ),
                )
                new_rec = RunRecord(
                    run_name=alt_run_dir.name,
                    model_id=rec.model_id,
                    run_path=alt_run_dir,
                    model_path=alt_model_dir,
                    metrics_path=alt_metrics,
                    predictions_path=alt_preds,
                    training_history_path=alt_history,
                    model_display=rec.model_display,
                )
                updated_records[idx] = new_rec

        run_records = [r for r in updated_records if r is not None]

# Add archive 100-gene runs when 1000-gene models are missing (per dataset)
if USE_1000PLUS100:
    archive_root = config.results_root / "archive"
    if archive_root.exists():
        # NOTE(review): `successful_1000` is rebuilt here exactly as in the
        # previous block but is not read again within this block — it looks
        # like leftover code; confirm before removing.
        successful_1000: set[tuple[str, str]] = set()
        for record in raw_run_records:
            if "1000genes" not in record.run_name:
                continue
            metrics_path = record.metrics_path
            if metrics_path is None or not Path(metrics_path).exists():
                continue
            ds = _dataset_from_run_name(record.run_name) or ""
            if ds:
                successful_1000.add((ds, record.model_id))

        for run_dir in archive_root.iterdir():
            if not run_dir.is_dir() or "1000genes" not in run_dir.name:
                continue
            ds = _dataset_from_run_name(run_dir.name) or ""
            if not ds:
                continue
            models_dir = run_dir / "models"
            if not models_dir.exists():
                continue
            for model_dir in models_dir.iterdir():
                if not model_dir.is_dir():
                    continue
                metrics_path = _first_existing(model_dir, ("metrics_per_gene.csv", "metrics_by_gene.csv", "metrics_cv.csv"))
                if metrics_path and metrics_path.exists():
                    successful_1000.add((ds, model_dir.name))

        # Model ids that currently have a 1000-gene record, keyed by dataset
        # ("" when the run name carries no known dataset keyword).
        dataset_models_1000 = {}
        for rec in run_records:
            if "1000genes" not in rec.run_name:
                continue
            ds = "embryonic" if "embryonic" in rec.run_name.lower() else "endothelial" if "endothelial" in rec.run_name.lower() else ""
            dataset_models_1000.setdefault(ds, set()).add(rec.model_id)

        # Archived 100-gene model directories that have metrics, grouped by
        # (dataset, model id).
        archive_candidates = {}
        for run_dir in archive_root.iterdir():
            if not run_dir.is_dir() or "100genes" not in run_dir.name:
                continue
            ds = "embryonic" if "embryonic" in run_dir.name.lower() else "endothelial" if "endothelial" in run_dir.name.lower() else ""
            if not ds:
                continue
            models_dir = run_dir / "models"
            if not models_dir.exists():
                continue
            for model_dir in models_dir.iterdir():
                if not model_dir.is_dir():
                    continue
                metrics_path = _first_existing(model_dir, ("metrics_per_gene.csv", "metrics_by_gene.csv", "metrics_cv.csv"))
                if not metrics_path or not metrics_path.exists():
                    continue
                key = (ds, model_dir.name)
                archive_candidates.setdefault(key, []).append((run_dir, model_dir, metrics_path))

        # Same regex as _run_timestamp_key, without the str() coercion.
        def _timestamp_key(name: str) -> tuple[int, int]:
            import re
            match = re.search(r"_(\d{8})_(\d{6})", name)
            if not match:
                return (0, 0)
            return (int(match.group(1)), int(match.group(2)))

        # For every dataset/model pair without a 1000-gene record, add the
        # newest archived 100-gene run as a stand-in.
        added_records = []
        for (ds, model_id), candidates in archive_candidates.items():
            if model_id in dataset_models_1000.get(ds, set()):
                continue
            candidates.sort(key=lambda item: _timestamp_key(item[0].name), reverse=True)
            run_dir, model_dir, metrics_path = candidates[0]
            preds_path = _first_existing(model_dir, ("predictions_raw.csv", "predictions.csv"))
            history_path = _first_existing(model_dir, ("training_history.csv", "training_history_loss.csv"))
            new_rec = RunRecord(
                run_name=run_dir.name,
                model_id=model_id,
                run_path=run_dir,
                model_path=model_dir,
                metrics_path=metrics_path,
                predictions_path=preds_path,
                training_history_path=history_path,
                model_display=MODEL_ID_TO_DISPLAY.get(model_id, _guess_display_name(model_id)),
            )
            added_records.append(new_rec)
        if added_records:
            run_records = list(run_records) + added_records

def _select_preferred_runs(records: list[RunRecord], prefer_1000: bool) -> list[RunRecord]:
    """Keep one record per (model_id, dataset) pair, favouring 1000-gene runs.

    With ``prefer_1000`` false the records are returned unchanged (as a new
    list). Otherwise each group is narrowed to the first non-empty tier below
    and the member with the highest ``_run_timestamp_key`` is kept:

    1. ``1000genes`` runs that have metrics
    2. ``100genes`` runs that have metrics
    3. any ``1000genes`` run
    4. any ``100genes`` run
    5. the whole group
    """
    if not records:
        return []
    if not prefer_1000:
        return list(records)

    grouped: dict[tuple[str, str], list[RunRecord]] = {}
    for rec in records:
        dataset = _dataset_from_run_name(rec.run_name) or ""
        grouped.setdefault((rec.model_id, dataset), []).append(rec)

    chosen: list[RunRecord] = []
    for members in grouped.values():
        # Tiers are built eagerly, in preference order; the group itself is the last resort.
        tiers = (
            [m for m in members if "1000genes" in m.run_name and _has_metrics(m)],
            [m for m in members if "100genes" in m.run_name and _has_metrics(m)],
            [m for m in members if "1000genes" in m.run_name],
            [m for m in members if "100genes" in m.run_name],
            members,
        )
        pool = next(tier for tier in tiers if tier)
        # max() returns the first maximal element, matching a stable descending sort.
        chosen.append(max(pool, key=lambda m: _run_timestamp_key(m.run_name)))

    return chosen

# Narrow the discovered records with _select_preferred_runs (a no-op copy when
# PREFER_1000_FALLBACK_100 is false), then tabulate them for display.
run_records = _select_preferred_runs(list(run_records), PREFER_1000_FALLBACK_100)
run_df = run_records_to_frame(run_records)
run_df.sort_values(["run_name", "model_id"], inplace=True)
# Display-only copy: path columns are passed through to_relative_path; run_df keeps the original values.
run_df_display = run_df.copy()
for column in ("run_path", "model_path", "metrics_path", "predictions_path", "training_history_path"):
    run_df_display[column] = run_df_display[column].map(lambda value: to_relative_path(value, config.project_root))
# Record where results were read from and how many runs / model folders were found.
analysis_metadata.update({
    "results_root": to_relative_path(config.results_root, config.project_root),
    "fig_dir": to_relative_path(config.fig_dir, config.project_root),
    "reports_dir": to_relative_path(config.reports_dir, config.project_root),
    "run_count": run_df["run_name"].nunique(),
    "model_count": len(run_df),
    "model_lookup_path": to_relative_path(config.lookup_path, config.project_root),
})
# setdefault: an "include_globs" entry written by an earlier cell takes precedence over the config value.
analysis_metadata.setdefault("include_globs", config.run_include_globs)
display(Markdown(f"**Scanning results root:** `{analysis_metadata['results_root']}`"))
# The subset descriptor and include filters are only shown when they are set.
subset_descriptor = analysis_metadata.get("subset_descriptor")
if subset_descriptor:
    display(Markdown(f"**Subset criteria:** {subset_descriptor}"))
include_filters = analysis_metadata.get("include_globs")
if include_filters:
    include_text = ', '.join(str(item) for item in include_filters)
    display(Markdown(f"**Include filters:** `{include_text}`"))
display(Markdown(
    f"**Figure output:** `{analysis_metadata['fig_dir']}` | **Reports:** `{analysis_metadata['reports_dir']}`"
))
display(run_df_display)
print(
    "Discovered",
    analysis_metadata["model_count"],
    "model folders across",
    analysis_metadata["run_count"],
    "runs.",
)


## 5. Load Metrics and Compute Summaries

Load per-gene metrics and compute split-level summaries.


In [None]:
# Load metrics for the current run_records in long and wide form.
metrics_long, metrics_wide = load_metrics(run_records)

# Table: mean test Pearson for all runs with metrics
# Built from raw_run_records (defined in an earlier cell), not from run_records.
all_records = attach_lookup(raw_run_records, model_lookup)
all_metrics_long, all_metrics_wide = load_metrics(all_records)
# NOTE(review): the column name "test_pearson" is hard-coded here, while the rest of
# the cell derives split names from config.primary_split - confirm they always agree.
if all_metrics_wide.empty or "test_pearson" not in all_metrics_wide:
    print("No test-set metrics available for full run table.")
else:
    # One row per (run_name, model_id) holding the mean of test_pearson.
    mean_df = (
        all_metrics_wide.groupby(["run_name", "model_id"])
        ["test_pearson"]
        .mean()
        .reset_index()
        .rename(columns={"test_pearson": "test_pearson_mean"})
    )
    # Attach each model's folder; the left merge keeps every row of mean_df.
    path_df = run_records_to_frame(all_records)[["run_name", "model_id", "model_path"]]
    all_run_mean_test_pearson = mean_df.merge(path_df, on=["run_name", "model_id"], how="left")
    all_run_mean_test_pearson["dataset"] = all_run_mean_test_pearson["run_name"].map(_dataset_from_run_name)
    all_run_mean_test_pearson["model_path"] = all_run_mean_test_pearson["model_path"].map(
        lambda value: to_relative_path(value, config.project_root)
    )
    display(all_run_mean_test_pearson.sort_values(["run_name", "model_id"]))
    TABLES["all_run_mean_test_pearson"] = all_run_mean_test_pearson
# Everything below needs metrics for the selected runs, so stop here when none were loaded.
if metrics_long.empty:
    raise RuntimeError("No metrics available for plotting.")

# Warn (without failing) about configured splits that do not appear in the metrics.
split_filter = {config.primary_split, config.val_split, config.train_split}
available_splits = sorted(metrics_long["split"].unique())
missing_splits = split_filter.difference(available_splits)
if missing_splits:
    print("Warning: the following splits are missing from metrics files:", sorted(missing_splits))

# Per-split copies of the long table; a split that is absent yields an empty frame.
test_metrics = metrics_long[metrics_long["split"] == config.primary_split].copy()
val_metrics = metrics_long[metrics_long["split"] == config.val_split].copy()
train_metrics = metrics_long[metrics_long["split"] == config.train_split].copy()
summary_df = compute_model_summary(
    metrics_wide, [config.primary_split, config.val_split, config.train_split]
)
if summary_df.empty:
    raise RuntimeError("Unable to compute summary statistics from metrics.")

summary_reset = summary_df.reset_index()
# NOTE(review): the "best" model is taken from the first summary row, so this relies on
# compute_model_summary returning rows ordered best-first - confirm.
analysis_metadata["best_model_id"] = summary_reset.iloc[0]["model_id"]
analysis_metadata["best_model_display"] = summary_reset.iloc[0]["model_display"]
analysis_metadata["best_run_name"] = summary_reset.iloc[0]["run_name"]

# Display order: descending mean test Pearson with one entry per model_id;
# without that column the summary's own row order is used.
if "test_pearson_mean" in summary_reset:
    model_display_order = summary_reset.sort_values(
        by="test_pearson_mean", ascending=False
)[["model_display", "model_id"]].drop_duplicates("model_id")["model_display"].tolist()
else:
    model_display_order = summary_reset["model_display"].tolist()

display(summary_reset)

# Shared state read by the later cells of this notebook.
analysis_state = {
    "run_df": run_df,
    "metrics_long": metrics_long,
    "metrics_wide": metrics_wide,
    "test_metrics": test_metrics,
    "val_metrics": val_metrics,
    "train_metrics": train_metrics,
    "summary_df": summary_df,
    "summary_reset": summary_reset,
    "model_display_order": model_display_order,
    "model_short_name_map": MODEL_ID_TO_SHORT.copy(),
}

## 6. Supporting Tables

Write helper tables that back figures and downstream analysis.


In [None]:
# Pull the working frames back out of analysis_state.
summary_df = analysis_state["summary_df"]
summary_reset = analysis_state["summary_reset"]
metrics_wide = analysis_state["metrics_wide"]
metrics_long = analysis_state["metrics_long"]
test_metrics = analysis_state["test_metrics"]
val_metrics = analysis_state["val_metrics"]
train_metrics = analysis_state["train_metrics"]

# Gene x model matrix of validation Pearson; stays an empty frame when the column is absent.
# pivot_table's default aggregation (mean) applies if a gene/model pair has several rows.
val_pearson_per_gene = pd.DataFrame()
if f"{config.val_split}_pearson" in metrics_wide:
    val_pearson_per_gene = metrics_wide.pivot_table(
        index="gene",
        columns="model_display",
        values=f"{config.val_split}_pearson",
    )

# Same matrix for the primary (test) split.
test_pearson_per_gene = pd.DataFrame()
if f"{config.primary_split}_pearson" in metrics_wide:
    test_pearson_per_gene = metrics_wide.pivot_table(
        index="gene",
        columns="model_display",
        values=f"{config.primary_split}_pearson",
    )

# Hand the master per-gene table and the model summary to maybe_store_table under TABLES.
maybe_store_table(
    TABLES,
    "metrics_per_gene_master",
    metrics_long.sort_values(["split", "run_name", "model_id", "gene"])
)
maybe_store_table(TABLES, "summary_metrics_all_models", summary_reset)

# Publish the pivots for later cells; "split_mean_summary" is stored as the same object as summary_reset.
analysis_state.update(
    {
        "val_pearson_per_gene": val_pearson_per_gene,
        "test_pearson_per_gene": test_pearson_per_gene,
        "split_mean_summary": summary_reset,
    }
)

## 7. Test Pearson Heatmap

Model-level test Pearson comparison.


In [None]:
# Single-column heatmap of mean test Pearson, one row per model.
summary_reset = analysis_state["summary_reset"]
model_order = analysis_state["model_display_order"]
if "test_pearson_mean" not in summary_reset:
    print("Test Pearson summary unavailable; skipping heatmap.")
    fig_test_heatmap = None
else:
    # Keep each model's highest-scoring summary row: sort descending, then drop later duplicates of model_id.
    ranked = summary_reset.sort_values("test_pearson_mean", ascending=False)
    best_per_model = ranked.drop_duplicates("model_id")
    heatmap_series = best_per_model.set_index("model_display")["test_pearson_mean"]
    # Reorder to the shared display order; entries that end up NaN after reindexing are dropped.
    heatmap_df = heatmap_series.reindex(model_order).dropna().to_frame(name="Mean Test Pearson")
    if heatmap_df.empty:
        print("No aggregated test Pearson values available; skipping heatmap.")
        fig_test_heatmap = None
    else:
        vmin, vmax = compute_heatmap_limits(heatmap_df.values)
        # Figure height grows with the number of models, with a floor of 4 inches.
        fig_height = max(4, 0.4 * len(heatmap_df))
        fig_test_heatmap, ax = plt.subplots(figsize=(4.5, fig_height))
        sns.heatmap(
            heatmap_df,
            cmap="viridis",
            vmin=vmin,
            vmax=vmax,
            annot=True,
            fmt=".3f",
            linewidths=0.4,
            linecolor="#f2f2f2",
            cbar_kws={"label": "Mean Test Pearson"},
            ax=ax,
        )
        ax.set_title("Mean Test Pearson by Model")
        ax.set_xlabel("")
        ax.set_ylabel("Model")
        ax.set_xticklabels(ax.get_xticklabels(), rotation=0)
        sns.despine(fig_test_heatmap, left=True, bottom=True)
        plt.tight_layout()
# register_figure is called even when the figure is None (skipped plot).
register_figure(FIGURES, "test_pearson_heatmap_all_models", fig_test_heatmap)
if fig_test_heatmap is not None:
    display(fig_test_heatmap)
    # Close after display so the figure does not stay open in pyplot's state.
    plt.close(fig_test_heatmap)

### Top-Gene Test Heatmaps

Runs only if per-gene test Pearson data is available.


In [None]:
# Heatmap of test Pearson for the genes with the highest mean across models.
summary_reset = analysis_state["summary_reset"]
test_pivot = analysis_state.get("test_pearson_per_gene")
test_metrics = analysis_state["test_metrics"]
if test_pivot is None or test_pivot.empty:
    print("Test Pearson data unavailable; skipping heatmap.")
    fig_test_heatmap_top = None
else:
    # Rank genes by their mean test Pearson over all rows of test_metrics and keep the top N.
    top_test_genes = (
        test_metrics.groupby("gene")["pearson"].mean().sort_values(ascending=False).head(config.top_gene_count)
    )
    analysis_state["test_top_genes"] = top_test_genes.index.tolist()
    # Restrict the pivot to those genes; only genes present in both indexes are kept.
    test_top_subset = test_pivot.loc[test_pivot.index.intersection(top_test_genes.index)]
    if test_top_subset.empty:
        print("Top-performing gene subset empty; skipping test heatmap.")
        fig_test_heatmap_top = None
    else:
        # Column order follows summary_reset's row order (first occurrence per model_id),
        # not analysis_state["model_display_order"].
        model_order = summary_reset[["model_display", "model_id"]].drop_duplicates("model_id")["model_display"].tolist()
        ordered_columns = [col for col in model_order if col in test_top_subset.columns]
        test_top_subset = test_top_subset.reindex(columns=ordered_columns)
        vmin, vmax = compute_heatmap_limits(
            test_top_subset.values, lower_percentile=10.0, upper_percentile=95.0
)
        # Figure height grows with the number of genes, with a floor of 4 inches.
        fig_height_top = max(4, 0.6 * len(test_top_subset.index))
        fig_test_heatmap_top, ax = plt.subplots(figsize=(10, fig_height_top))
        sns.heatmap(
            test_top_subset,
            cmap="crest",
            vmin=vmin,
            vmax=vmax,
            annot=True,
            fmt=".2f",
            linewidths=0.3,
            linecolor="#f5f5f5",
            cbar_kws={"label": "Test Pearson"},
            ax=ax,
        )
        ax.set_title("Test Pearson (top genes by mean across models)")
        ax.set_xlabel("Model")
        ax.set_ylabel("Gene")
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=10)
        ax.set_yticklabels(ax.get_yticklabels(), fontsize=10)
        sns.despine(fig_test_heatmap_top, left=True, bottom=True)
        plt.tight_layout()
# register_figure is called even when the figure is None (skipped plot).
register_figure(FIGURES, "test_pearson_heatmap_top", fig_test_heatmap_top)
if fig_test_heatmap_top is not None:
    display(fig_test_heatmap_top)
    plt.close(fig_test_heatmap_top)

# --- Top genes per model heatmap (model-specific gene sets) ---
# Rows are the union of every model's own top genes; reads test_pivot and
# summary_reset assigned earlier in this cell.
if test_pivot is None or test_pivot.empty:
    fig_test_heatmap_top_per_model = None
else:
    model_order = summary_reset[["model_display", "model_id"]].drop_duplicates("model_id")["model_display"].tolist()
    ordered_columns = [col for col in model_order if col in test_pivot.columns]
    # Collect genes in first-seen order while walking the models left to right,
    # so each gene appears once even if several models rank it highly.
    seen_genes: set[str] = set()
    gene_order: list[str] = []
    for model_name in ordered_columns:
        series = test_pivot[model_name].dropna().sort_values(ascending=False).head(config.top_gene_count)
        for gene in series.index:
            if gene not in seen_genes:
                seen_genes.add(gene)
                gene_order.append(gene)
    if not gene_order:
        print("No per-model top genes available; skipping heatmap.")
        fig_test_heatmap_top_per_model = None
    else:
        per_model_subset = test_pivot.reindex(index=gene_order, columns=ordered_columns)
        vmin, vmax = compute_heatmap_limits(
            per_model_subset.values, lower_percentile=10.0, upper_percentile=95.0
        )
        # Figure height grows with the number of genes, with a floor of 4 inches.
        fig_height = max(4, 0.4 * len(per_model_subset.index))
        fig_test_heatmap_top_per_model, ax = plt.subplots(figsize=(10, fig_height))
        sns.heatmap(
            per_model_subset,
            cmap="crest",
            vmin=vmin,
            vmax=vmax,
            annot=True,
            fmt=".2f",
            linewidths=0.3,
            linecolor="#f5f5f5",
            cbar_kws={"label": "Test Pearson"},
            ax=ax,
        )
        ax.set_title("Test Pearson (top genes per model)")
        ax.set_xlabel("Model")
        ax.set_ylabel("Gene")
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=10)
        ax.set_yticklabels(ax.get_yticklabels(), fontsize=9)
        sns.despine(fig_test_heatmap_top_per_model, left=True, bottom=True)
        plt.tight_layout()
# register_figure is called even when the figure is None (skipped plot).
register_figure(FIGURES, "test_pearson_heatmap_top_per_model", fig_test_heatmap_top_per_model)
if fig_test_heatmap_top_per_model is not None:
    display(fig_test_heatmap_top_per_model)
    plt.close(fig_test_heatmap_top_per_model)




# --- Per-model top-20 heatmap blocks (horizontal layout) ---
# One column per model and one row per rank; each cell is annotated with the
# gene name and its test Pearson. The cut-off of 20 is fixed here and does not
# use config.top_gene_count.
if test_pivot is None or test_pivot.empty:
    fig_test_heatmap_top20_blocks = None
else:
    model_order = summary_reset[["model_display", "model_id"]].drop_duplicates("model_id")["model_display"].tolist()
    ordered_columns = [col for col in model_order if col in test_pivot.columns]
    # values: colour data per model; annot: matching cell labels.
    values = {}
    annot = {}
    for model_name in ordered_columns:
        series = test_pivot[model_name].dropna().sort_values(ascending=False).head(20)
        # Pad to 20 rows to keep heatmap rectangular
        padded_values = series.tolist() + [float("nan")] * (20 - len(series))
        padded_genes = series.index.tolist() + [""] * (20 - len(series))
        values[model_name] = padded_values
        # Padded slots have an empty gene name and therefore get an empty label.
        annot[model_name] = [f"{g}\n{v:.2f}" if g else "" for g, v in zip(padded_genes, padded_values)]

    if not values:
        print("No per-model top-20 genes available; skipping block heatmap.")
        fig_test_heatmap_top20_blocks = None
    else:
        heatmap_df = pd.DataFrame(values, index=[f"rank {i}" for i in range(1, 21)])
        annot_df = pd.DataFrame(annot, index=heatmap_df.index)
        # Width grows with the number of models (floor of 8 inches); height is fixed.
        fig_width = max(8, 1.2 * len(heatmap_df.columns))
        fig_height = 10
        fig_test_heatmap_top20_blocks, ax = plt.subplots(figsize=(fig_width, fig_height))
        # fmt="" because the annotations are pre-formatted strings.
        sns.heatmap(
            heatmap_df,
            cmap="crest",
            annot=annot_df,
            fmt="",
            linewidths=0.3,
            linecolor="#f5f5f5",
            cbar_kws={"label": "Test Pearson"},
            ax=ax,
        )
        ax.set_title("Test Pearson (top 20 genes per model)")
        ax.set_xlabel("Model")
        ax.set_ylabel("Rank (per model)")
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=9)
        ax.set_yticklabels(ax.get_yticklabels(), fontsize=9)
        sns.despine(fig_test_heatmap_top20_blocks, left=True, bottom=True)
        plt.tight_layout()
# register_figure is called even when the figure is None (skipped plot).
register_figure(FIGURES, "test_pearson_heatmap_top20_per_model_blocks", fig_test_heatmap_top20_blocks)
if fig_test_heatmap_top20_blocks is not None:
    display(fig_test_heatmap_top20_blocks)
    plt.close(fig_test_heatmap_top20_blocks)

## 8. Test Distribution Profiles

Per-gene test-set Pearson distributions by model.


In [None]:
# Violin plot of per-gene test Pearson for each model, with a jittered sample of
# the underlying points drawn on top.
test_metrics = analysis_state["test_metrics"].copy()
summary_reset = analysis_state["summary_reset"]
model_display_order = analysis_state["model_display_order"]
if test_metrics.empty:
    print("Test metrics unavailable; skipping violin plot.")
    fig_violin = None
else:
    # Order models by ascending mean test Pearson. The summary column is guarded
    # (as in the heatmap cell) so a summary without it falls through to
    # model_display_order instead of raising KeyError.
    mean_by_model = (
        summary_reset.groupby("model_display")["test_pearson_mean"].mean()
        if "test_pearson_mean" in summary_reset
        else pd.Series(dtype=float)
    )
    mean_order_series = mean_by_model.sort_values(ascending=True)
    model_order = [model for model in mean_order_series.index if model in model_display_order]
    if not model_order:
        model_order = model_display_order
    # One fixed colour per model, shared by the violins and the jittered points.
    violin_palette = sns.color_palette("Set2", n_colors=len(model_order))
    palette_map = dict(zip(model_order, violin_palette))
    fig_width = max(12, 0.8 * max(6, len(model_order)))
    fig_violin, ax = plt.subplots(figsize=(fig_width, 6))
    sns.violinplot(
        data=test_metrics,
        x="model_display",
        y="pearson",
        hue="model_display",
        order=model_order,
        hue_order=model_order,
        palette=palette_map,
        density_norm="width",
        inner="quartile",
        linewidth=1.0,
        ax=ax,
        legend=False,
    )
    # Overlay at most 3000 points; test_metrics is known to be non-empty in this
    # branch, so no further emptiness check is needed. The seed keeps the sample
    # reproducible between runs.
    sample_size = min(len(test_metrics), 3000)
    jitter_sample = (
        test_metrics.sample(sample_size, random_state=config.random_seed)
        if len(test_metrics) > sample_size
        else test_metrics
    )
    sns.stripplot(
        data=jitter_sample,
        x="model_display",
        y="pearson",
        order=model_order,
        hue="model_display",
        hue_order=model_order,
        palette=palette_map,
        dodge=False,
        alpha=0.45,
        size=3.0,
        jitter=0.12,
        linewidth=0.3,
        edgecolor="#2b2b2b",
        marker="o",
        ax=ax,
        legend=False,
    )
    # Always show at least [-0.5, 1.0]; widen by 0.05 when the data reaches beyond that.
    metric_min = test_metrics["pearson"].min()
    metric_max = test_metrics["pearson"].max()
    ymin = min(-0.5, metric_min - 0.05)
    ymax = max(1.0, metric_max + 0.05)
    ax.set_ylim(ymin, ymax)
    ax.set_xlabel("Model")
    ax.set_ylabel("Per-gene Test Pearson")
    ax.set_title("Per-gene Test Pearson Distribution by Model")
    # Zero-correlation reference line.
    ax.axhline(0.0, color="#777777", linestyle="--", linewidth=1)
    plt.setp(ax.get_xticklabels(), rotation=35, ha="right")
    sns.despine(fig_violin, left=True, bottom=True)
    plt.tight_layout()
# register_figure is called even when the figure is None (skipped plot).
register_figure(FIGURES, "test_pearson_violin", fig_violin)
if fig_violin is not None:
    display(fig_violin)
    plt.close(fig_violin)


### Split Comparison by Model


In [None]:
# Grouped box plot of per-gene Pearson for the train / val / test splits of each
# model, with a jittered sample of the underlying points drawn on top.
metrics_long = analysis_state["metrics_long"]
summary_reset = analysis_state["summary_reset"]
model_display_order = analysis_state["model_display_order"]
splits_of_interest = [config.train_split, config.val_split, config.primary_split]
subset = metrics_long[metrics_long["split"].isin(splits_of_interest)].copy()
if subset.empty or "pearson" not in subset:
    print("Pearson metrics unavailable across requested splits; skipping split comparison plot.")
    fig_split_compare = None
else:
    subset = subset[["model_display", "split", "pearson"]].dropna()
    split_labels = {
        config.train_split: "Train",
        config.val_split: "Val",
        config.primary_split: "Test",
    }
    subset["split_label"] = subset["split"].map(split_labels).fillna(subset["split"].str.title())
    # Order models by descending mean test Pearson. The summary column is guarded
    # because this branch can run with only train/val rows; without the guard the
    # lookup raised KeyError before the model_display_order fallback was reached.
    test_means = (
        summary_reset.groupby("model_display")["test_pearson_mean"].mean().sort_values(ascending=False)
        if "test_pearson_mean" in summary_reset
        else pd.Series(dtype=float)
    )
    ordered_models = [model for model in test_means.index if model in subset["model_display"].unique()]
    if not ordered_models:
        ordered_models = model_display_order
    split_order = [split_labels[split] for split in splits_of_interest if split in split_labels]
    # One fixed colour per split label, shared by the boxes and the jittered points.
    box_colors = sns.color_palette("Set2", n_colors=len(split_order))
    box_palette = dict(zip(split_order, box_colors))
    fig_width = max(12, 0.75 * max(6, len(ordered_models)))
    fig_split_compare, ax = plt.subplots(figsize=(fig_width, 6.5))
    # fliersize=0 hides the box plot's own outlier markers; the strip plot shows points instead.
    sns.boxplot(
        data=subset,
        x="model_display",
        y="pearson",
        hue="split_label",
        order=ordered_models,
        hue_order=split_order,
        palette=box_palette,
        width=0.65,
        fliersize=0,
        ax=ax,
    )
    # Overlay at most 4000 points; the seed keeps the sample reproducible between runs.
    jitter_sample = subset.sample(4000, random_state=config.random_seed) if len(subset) > 4000 else subset
    sns.stripplot(
        data=jitter_sample,
        x="model_display",
        y="pearson",
        hue="split_label",
        order=ordered_models,
        hue_order=split_order,
        palette=box_palette,
        dodge=True,
        jitter=0.12,
        size=3.0,
        alpha=0.5,
        edgecolor="#2b2b2b",
        linewidth=0.4,
        marker="o",
        ax=ax,
        legend=False,
    )
    # Always show at least [-0.5, 1.0]; widen by 0.05 when the data reaches beyond that.
    metric_min = subset["pearson"].min()
    metric_max = subset["pearson"].max()
    ymin = min(-0.5, metric_min - 0.05)
    ymax = max(1.0, metric_max + 0.05)
    ax.set_ylim(ymin, ymax)
    ax.set_xlabel("Model")
    ax.set_ylabel("Per-gene Pearson")
    ax.set_title("Per-gene Pearson by Split and Model")
    plt.setp(ax.get_xticklabels(), rotation=35, ha="right")
    # Legend is anchored outside the axes on the right.
    ax.legend(title="Split", bbox_to_anchor=(1.02, 1), loc="upper left")
    sns.despine(fig_split_compare, left=True, bottom=True)
    fig_split_compare.tight_layout()
# register_figure is called even when the figure is None (skipped plot).
register_figure(FIGURES, "split_comparison_overview", fig_split_compare)
if fig_split_compare is not None:
    display(fig_split_compare)
    plt.close(fig_split_compare)


## 9. Per-Gene Performance Analysis

Detailed gene-level performance heatmaps and feature importance.


In [None]:
# Per-gene heatmap: genes x models performance
# Unlike the other figures, this one is only registered when it is actually drawn.
metrics_wide = analysis_state.get("metrics_wide", pd.DataFrame())
fig_gene_heatmap = None

if not metrics_wide.empty and f"{config.primary_split}_pearson" in metrics_wide:
    # Rebuilt from metrics_wide here rather than read from analysis_state["test_pearson_per_gene"].
    test_pearson_per_gene = metrics_wide.pivot_table(
        index="gene",
        columns="model_display",
        values=f"{config.primary_split}_pearson",
        aggfunc="mean",
    )

    # Require more than three genes before drawing.
    if not test_pearson_per_gene.empty and len(test_pearson_per_gene) > 3:
        # Order genes by mean performance
        gene_means = test_pearson_per_gene.mean(axis=1).sort_values(ascending=False)
        test_pearson_per_gene = test_pearson_per_gene.loc[gene_means.index]

        # Reorder columns by model performance
        col_means = test_pearson_per_gene.mean(axis=0).sort_values(ascending=False)
        test_pearson_per_gene = test_pearson_per_gene[col_means.index]

        # Figure size grows with the number of genes (rows) and models (columns).
        fig_height = max(8, 0.15 * len(test_pearson_per_gene))
        fig_width = max(9, 0.4 * len(test_pearson_per_gene.columns))
        fig_gene_heatmap, ax = plt.subplots(figsize=(fig_width, fig_height))

        vmin, vmax = compute_heatmap_limits(test_pearson_per_gene.values)
        sns.heatmap(
            test_pearson_per_gene,
            cmap="RdYlGn",
            vmin=vmin,
            vmax=vmax,
            annot=False,
            cbar_kws={"label": "Test Pearson"},
            ax=ax,
        )
        ax.set_title("Per-Gene Test Pearson by Model (sorted by mean performance)")
        ax.set_xlabel("Model")
        ax.set_ylabel("Gene")
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", fontsize=8)
        plt.setp(ax.get_yticklabels(), fontsize=7)
        fig_gene_heatmap.tight_layout()
        register_figure(FIGURES, "per_gene_heatmap", fig_gene_heatmap)
        display(fig_gene_heatmap)
        plt.close(fig_gene_heatmap)
    else:
        print("Insufficient gene-level data for heatmap.")
else:
    print("Test Pearson per-gene data unavailable for heatmap.")

In [None]:
# Feature importance - top features for the best model
# Keys missing from analysis_state default to an empty DataFrame.
run_df = analysis_state.get("run_df", pd.DataFrame())
summary_reset = analysis_state.get("summary_reset", pd.DataFrame())

# Helper function to get best model for a dataset
def _best_model_for_run(summary_df: pd.DataFrame) -> Optional[pd.Series]:
    if summary_df.empty:
        return None
    if "test_pearson_mean" in summary_df:
        summary_df = summary_df.sort_values("test_pearson_mean", ascending=False)
    return summary_df.iloc[0]

# Find a feature-importance table in the best model's folder and plot its top
# entries as a horizontal bar chart.
fig_top_features = None
best_models = []

# Use the best model overall from summary_reset
best_row = _best_model_for_run(summary_reset)

if best_row is not None and not run_df.empty:
    run_match = run_df[
        (run_df["run_name"] == best_row["run_name"]) & (run_df["model_id"] == best_row["model_id"])
    ]
    if not run_match.empty:
        row = run_match.iloc[0]
        model_dir = row["model_path"] if isinstance(row["model_path"], Path) else Path(row["model_path"])

        # Look for feature-importance tables directly inside the model folder
        # first; search its subfolders only when nothing is found there.
        patterns = [
            "feature_importance*.csv",
            "feature_importances*.csv",
            "feature_importance*.tsv",
            "feature_importances*.tsv",
            "feature_importance*.parquet",
            "feature_importances*.parquet",
        ]
        candidates = []
        for pattern in patterns:
            candidates.extend(model_dir.glob(pattern))
        if not candidates:
            for pattern in patterns:
                candidates.extend(model_dir.glob(f"**/{pattern}"))

        # De-duplicate by resolved path (overlapping patterns can match the same
        # file twice) and defensively skip image files.
        unique_candidates = []
        seen = set()
        for path in candidates:
            resolved = path.resolve()
            if resolved in seen or resolved.suffix.lower() in {".png", ".jpg", ".jpeg"}:
                continue
            seen.add(resolved)
            unique_candidates.append(resolved)

        # Use the first candidate (in sorted path order) that can be read and
        # exposes both a feature column and an importance column.
        importance_df = None
        for candidate in sorted(unique_candidates):
            try:
                if candidate.suffix.lower() == ".parquet":
                    df = pd.read_parquet(candidate)
                else:
                    sep = "\t" if candidate.suffix.lower() in {".tsv", ".txt"} else ","
                    df = pd.read_csv(candidate, sep=sep)
            except Exception as exc:
                # Best effort: report the unreadable file instead of skipping it
                # silently, then try the next candidate.
                print(f"Warning: could not read feature importance file {candidate}: {exc}")
                continue
            if df.empty:
                continue
            # Case-insensitive lookup of the original column names.
            lower_cols = {col.lower(): col for col in df.columns}
            feature_col = next((
                lower_cols[key]
                for key in ("feature", "feature_name", "name", "variable", "feature_id", "column")
                if key in lower_cols
            ), None)
            importance_col = next((
                lower_cols[key]
                for key in (
                    "importance",
                    "importance_score",
                    "importance_mean",
                    "score",
                    "value",
                    "gain",
                    "weight",
                )
                if key in lower_cols
            ), None)
            if feature_col is None or importance_col is None:
                continue
            # Normalise to "feature" / "importance" columns; rows whose importance
            # is not numeric are dropped.
            out = df.copy()
            out.rename(columns={feature_col: "feature", importance_col: "importance"}, inplace=True)
            out["feature"] = out["feature"].astype(str)
            out["importance"] = pd.to_numeric(out["importance"], errors="coerce")
            out = out.dropna(subset=["feature", "importance"])
            if out.empty:
                continue
            extra_cols = [col for col in out.columns if col not in {"feature", "importance"}]
            out = out[["feature", "importance", *extra_cols]]
            out.sort_values("importance", ascending=False, inplace=True)
            importance_df = out.reset_index(drop=True)
            break

        if importance_df is not None and not importance_df.empty and "feature" in importance_df.columns and "importance" in importance_df.columns:
            importance_df = importance_df.nlargest(config.top_gene_count, "importance").copy()
            best_models.append((importance_df, best_row["model_display"]))

if best_models:
    fig_top_features, ax = plt.subplots(figsize=(10, 6))
    importance_df, model_name = best_models[0]
    importance_df = importance_df.copy()
    # Tick label: the part after the last "|" in the feature name, cut to 20 characters.
    importance_df["feature_short"] = importance_df["feature"].str.split("|").str[-1].str[:20]
    ax.barh(range(len(importance_df)), importance_df["importance"], color="#2c7fb8", alpha=0.85)
    ax.set_yticks(range(len(importance_df)))
    ax.set_yticklabels(importance_df["feature_short"], fontsize=9)
    ax.set_xlabel("Importance")
    # Report the number of bars actually drawn, which can be lower than
    # config.top_gene_count when the table has fewer rows.
    ax.set_title(f"Top {len(importance_df)} Features | Best model: {model_name}")
    # Largest importance at the top of the chart.
    ax.invert_yaxis()
    sns.despine(ax=ax, left=True, bottom=True)
    fig_top_features.tight_layout()
    register_figure(FIGURES, "top_features", fig_top_features)
    display(fig_top_features)
    plt.close(fig_top_features)
else:
    print("No feature importance data available for visualization.")

## 10. Embryonic vs Endothelial Comparisons

Train/val boxplots and test-set violin plots for each dataset.


In [None]:
# Read shared state defensively: when analysis_state was never defined, an empty
# dict is used and the lookups below return None / [].
analysis_state = globals().get("analysis_state", {})
metrics_long = analysis_state.get("metrics_long")
run_df = analysis_state.get("run_df")
model_display_order = analysis_state.get("model_display_order", [])

# Explicit run selections per dataset (run names or paths); empty by default.
EMBRYONIC_RUN_DIRECTORY_SELECTION: list[str | Path] = [
    # Examples:
    # "spear_100genes_cpu_embryonic_20251218_114225_xgboost",
]
ENDOTHELIAL_RUN_DIRECTORY_SELECTION: list[str | Path] = [
    # Examples:
    # "spear_100genes_cpu_endothelial_20251220_091232_xgboost",
]
# Glob patterns matched against run names to assign runs to each dataset.
EMBRYONIC_RUN_GLOB_SELECTION: list[str] = [
    "*embryonic*",
]
ENDOTHELIAL_RUN_GLOB_SELECTION: list[str] = [
    "*endothelial*",
]

def _select_run_names(run_df: pd.DataFrame, entries: list[str | Path], globs: list[str]) -> set[str]:
    if run_df is None or run_df.empty:
        return set()
    manual_names = [name for name in (_resolve_run_name(entry) for entry in entries) if name]
    glob_patterns = [pattern.strip() for pattern in globs if pattern and pattern.strip()]
    selected = set(manual_names)
    if glob_patterns:
        for run_name in run_df["run_name"].dropna().unique():
            if _matches_any(run_name, glob_patterns):
                selected.add(run_name)
    return selected

def _summarize_selection(label: str, run_names: set[str], all_runs: list[str]) -> None:
    if run_names:
        print(f"{label} runs ({len(run_names)}):", sorted(run_names))
    else:
        print(f"No {label.lower()} runs matched.")
        print("Available run names:", all_runs)

def _subset_metrics(metrics_long: pd.DataFrame, run_names: set[str], splits: list[str], dataset: str) -> pd.DataFrame:
    """Return the metric rows for the given runs and splits, tagged with a dataset label.

    Args:
        metrics_long: Long-format metrics with ``run_name`` and ``split`` columns.
        run_names: Runs to keep.
        splits: Split names to keep.
        dataset: Value written to the ``dataset`` column of the result.

    Returns:
        A copy of the matching rows with a ``dataset`` column added, or a
        column-less empty DataFrame when the input is missing or empty, no runs
        were requested, or nothing matched.
    """
    nothing_to_report = metrics_long is None or metrics_long.empty or not run_names
    if nothing_to_report:
        return pd.DataFrame()

    in_selected_runs = metrics_long["run_name"].isin(run_names)
    in_selected_splits = metrics_long["split"].isin(splits)
    picked = metrics_long.loc[in_selected_runs & in_selected_splits].copy()
    if picked.empty:
        return pd.DataFrame()

    # Tag the copy only; the caller's frame stays untouched.
    picked["dataset"] = dataset
    return picked

# Build the per-dataset comparison figures: a grouped train/val boxplot and a
# test-split violin plot, both keyed by model and coloured by dataset.

_DATASET_PLOT_COLUMNS = ("model_display_name", "pearson", "dataset")


def _concat_non_empty(frames: list[pd.DataFrame]) -> pd.DataFrame:
    """Concatenate the frames that have rows; return an empty DataFrame when none do."""
    populated = [frame for frame in frames if not frame.empty]
    if not populated:
        return pd.DataFrame()
    return pd.concat(populated, ignore_index=True)


def _with_model_label(frame: pd.DataFrame) -> pd.DataFrame:
    """Add the `model_display_name` plotting column from `model_display` when it is absent."""
    if not frame.empty and "model_display_name" not in frame and "model_display" in frame:
        frame["model_display_name"] = frame["model_display"]
    return frame


def _missing_plot_columns(frame: pd.DataFrame) -> list[str]:
    """List the columns the dataset comparison plots need that `frame` lacks."""
    return [column for column in _DATASET_PLOT_COLUMNS if column not in frame.columns]


fig_train_val_box = None
fig_test_violin_by_dataset = None

if metrics_long is None or run_df is None or run_df.empty:
    print("Run metadata unavailable; skipping dataset comparison figures.")
else:
    all_runs = sorted(run_df["run_name"].dropna().unique())
    endothelial_runs = _select_run_names(run_df, ENDOTHELIAL_RUN_DIRECTORY_SELECTION, ENDOTHELIAL_RUN_GLOB_SELECTION)
    embryonic_runs = _select_run_names(run_df, EMBRYONIC_RUN_DIRECTORY_SELECTION, EMBRYONIC_RUN_GLOB_SELECTION)

    # When one dataset has no selectors configured at all, treat every run not
    # claimed by the other dataset as belonging to it.
    if not embryonic_runs and endothelial_runs and not EMBRYONIC_RUN_DIRECTORY_SELECTION and not EMBRYONIC_RUN_GLOB_SELECTION:
        embryonic_runs = set(all_runs) - endothelial_runs
    elif not endothelial_runs and embryonic_runs and not ENDOTHELIAL_RUN_DIRECTORY_SELECTION and not ENDOTHELIAL_RUN_GLOB_SELECTION:
        endothelial_runs = set(all_runs) - embryonic_runs

    _summarize_selection("Embryonic", embryonic_runs, all_runs)
    _summarize_selection("Endothelial", endothelial_runs, all_runs)

    embryonic_test = _subset_metrics(metrics_long, embryonic_runs, [config.primary_split], "Embryonic")
    endothelial_test = _subset_metrics(metrics_long, endothelial_runs, [config.primary_split], "Endothelial")
    embryonic_train_val = _subset_metrics(metrics_long, embryonic_runs, [config.train_split, config.val_split], "Embryonic")
    endothelial_train_val = _subset_metrics(metrics_long, endothelial_runs, [config.train_split, config.val_split], "Endothelial")

    # Empty placeholders are dropped before concatenation so they cannot
    # influence the combined frame's dtypes.
    combined_test = _with_model_label(_concat_non_empty([embryonic_test, endothelial_test]))
    combined_train_val = _with_model_label(_concat_non_empty([embryonic_train_val, endothelial_train_val]))

    if combined_train_val.empty and combined_test.empty:
        print("No data for dataset comparison.")
    else:
        # --- Train/Val Boxplot by Dataset ---
        if not combined_train_val.empty:
            missing_columns = _missing_plot_columns(combined_train_val)
            if missing_columns:
                print(f"Train/val metrics lack columns {missing_columns}; skipping train/val boxplot.")
            else:
                fig_train_val_box, ax = plt.subplots(figsize=(12, 6))
                sns.boxplot(
                    data=combined_train_val,
                    x="model_display_name",
                    y="pearson",
                    hue="dataset",
                    ax=ax,
                    palette="Set2",
                )
                ax.set_xlabel("Model")
                ax.set_ylabel("Pearson Correlation")
                ax.set_title("Train/Val Pearson by Dataset (Grouped)")
                ax.legend(title="Dataset")
                plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
                fig_train_val_box.tight_layout()
                FIGURES["train_val_pearson_by_dataset_box"] = fig_train_val_box

        # --- Test Violin by Dataset ---
        if not combined_test.empty:
            missing_columns = _missing_plot_columns(combined_test)
            if missing_columns:
                print(f"Test metrics lack columns {missing_columns}; skipping test violin plot.")
            else:
                # Split violins need exactly two hue levels (older seaborn
                # raises otherwise), so fall back to plain violins when only
                # one dataset contributed rows.
                use_split = combined_test["dataset"].nunique() == 2
                fig_test_violin_by_dataset, ax = plt.subplots(figsize=(12, 6))
                sns.violinplot(
                    data=combined_test,
                    x="model_display_name",
                    y="pearson",
                    hue="dataset",
                    split=use_split,
                    ax=ax,
                    palette="Set2",
                )
                ax.set_xlabel("Model")
                ax.set_ylabel("Test Pearson Correlation")
                ax.set_title(
                    "Test Pearson by Dataset (Split Violin)" if use_split else "Test Pearson by Dataset"
                )
                ax.legend(title="Dataset")
                plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
                fig_test_violin_by_dataset.tight_layout()
                FIGURES["test_pearson_by_dataset_violin"] = fig_test_violin_by_dataset

# Publish the figures (or None when skipped) for later cells.
if "analysis_state" not in globals():
    analysis_state = {}
analysis_state["fig_train_val_box"] = fig_train_val_box
analysis_state["fig_test_violin_by_dataset"] = fig_test_violin_by_dataset

In [None]:
import re

# Bar charts of the best-scoring genes (mean test Pearson) for the top-ranked
# model. One figure is registered per model under a slug-based key, and the
# list of keys is published in analysis_state["top_gene_figure_keys"].
test_metrics = analysis_state["test_metrics"]
summary_reset = analysis_state["summary_reset"]
short_map = analysis_state.get("model_short_name_map", MODEL_ID_TO_SHORT)

top_gene_count = 30
# This cell produces no single combined figure; the name is kept bound to None
# on every path, as before.
fig_top_genes = None
figure_keys: list[str] = []

if test_metrics.empty:
    print("Test metrics unavailable; skipping top-gene visualisation.")
else:
    if "test_pearson_mean" in summary_reset:
        ranked_models = summary_reset.sort_values("test_pearson_mean", ascending=False)
    else:
        ranked_models = summary_reset
    top_model_limit = 1  # only the best model for top-gene performance plots
    top_models = ranked_models.drop_duplicates("model_id").head(top_model_limit)[["model_id", "model_display"]]

    gene_frames: list[pd.DataFrame] = []
    for row in top_models.itertuples(index=False):
        model_subset = test_metrics[test_metrics["model_id"] == row.model_id]
        if model_subset.empty:
            continue
        gene_stats = (
            model_subset.groupby("gene")["pearson"]
            .agg(mean_test_pearson="mean", pearson_std="std", observation_count="count")
            .reset_index()
            # Genes without any valid Pearson value cannot be ranked; without
            # this they could fill the chart with NaN bars labelled "nan".
            .dropna(subset=["mean_test_pearson"])
        )
        if gene_stats.empty:
            continue
        short_name = short_map.get(row.model_id, to_short_name(row.model_display)) or row.model_display
        gene_stats["model_display_long"] = row.model_display
        gene_stats["model_display_short"] = short_name
        gene_frames.append(gene_stats)

    if not gene_frames:
        print("No genes available for top-model bar chart.")
    else:
        top_gene_df = pd.concat(gene_frames, ignore_index=True)
        for long_name, short_name in (
            top_gene_df[["model_display_long", "model_display_short"]]
            .drop_duplicates()
            .itertuples(index=False, name=None)
        ):
            group = top_gene_df[top_gene_df["model_display_long"] == long_name]
            ordered = group.sort_values("mean_test_pearson", ascending=False).head(top_gene_count)
            if ordered.empty:
                continue
            fig, ax = plt.subplots(figsize=(9, 6))
            colors = sns.color_palette("crest", n_colors=len(ordered))
            bars = ax.barh(
                ordered["gene"],
                ordered["mean_test_pearson"],
                color=colors,
                edgecolor="#2b2b2b",
                linewidth=0.4,
            )
            # Report the number of genes actually drawn, which can be smaller
            # than top_gene_count for small gene panels.
            ax.set_title(f"Top {len(ordered)} Genes | {short_name}")
            ax.set_xlabel("Mean Test Pearson")
            ax.set_ylabel("Gene")
            data_max = ordered["mean_test_pearson"].max()
            data_min = ordered["mean_test_pearson"].min()
            label_offset = max(0.01, data_max * 0.01)
            x_max = max(1.0, data_max + 0.05) + label_offset * 4
            # Keep the axis anchored at zero unless some bars are negative, in
            # which case widen it so those bars are not clipped.
            x_min = min(0.0, data_min - label_offset)
            ax.set_xlim(x_min, x_max)
            ax.grid(axis="x", linestyle="--", alpha=0.3)
            ax.invert_yaxis()
            for bar, mean_val in zip(bars, ordered["mean_test_pearson"]):
                y_pos = bar.get_y() + bar.get_height() / 2
                ax.text(
                    mean_val + label_offset,
                    y_pos,
                    f"{mean_val:.3f}",
                    ha="left",
                    va="center",
                    fontsize=9,
                    color="#1a1a1a",
                )
            sns.despine(ax=ax, left=True, bottom=True)
            fig.tight_layout()
            # Figure key: lower-cased short name with non-alphanumerics collapsed to "_".
            slug = re.sub(r"[^0-9a-zA-Z]+", "_", short_name.lower()).strip("_") or "model"
            key = f"top_genes_test_performance_{slug}"
            register_figure(FIGURES, key, fig)
            display(fig)
            plt.close(fig)
            figure_keys.append(key)

analysis_state["top_gene_figure_keys"] = figure_keys


# --- Metric distributions and heatmaps (test split) ---
def _plot_test_metric_distribution(frame: pd.DataFrame, metric_key: str, title_label: str, order: list[str]):
    """Draw one metric as right-half violins with left-offset boxplots and jittered points.

    Args:
        frame: Test-split metrics with a ``model_display`` column and ``metric_key``.
        metric_key: Column to plot on the y axis.
        title_label: Human-readable metric name for the title and y label.
        order: Model display names, left to right.

    Returns:
        The created matplotlib figure (not displayed or closed here).
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.violinplot(
        data=frame,
        x="model_display",
        y=metric_key,
        order=order,
        ax=ax,
        palette="Set2",
        inner=None,
        cut=0,
        linewidth=1.0,
    )

    # Clip violins to the right half by clamping each outline's x coordinates
    # at its category position.
    # NOTE(review): assumes the first len(order) collections are the violins in
    # category order — confirm against the seaborn version in use.
    for idx, poly in enumerate(ax.collections[:len(order)]):
        for path in poly.get_paths():
            verts = path.vertices
            verts[:, 0] = np.clip(verts[:, 0], idx, np.inf)

    # Left-half boxplot.
    data_by_model = [
        frame.loc[frame["model_display"] == name, metric_key].dropna().values
        for name in order
    ]
    positions = np.arange(len(order)) - 0.2
    box = ax.boxplot(
        data_by_model,
        positions=positions,
        widths=0.2,
        patch_artist=True,
        showfliers=False,
        medianprops={"color": "#222222", "linewidth": 1.0},
    )
    for patch in box["boxes"]:
        patch.set(facecolor="#ffffff", edgecolor="#222222", linewidth=1.0)
    for whisker in box["whiskers"]:
        whisker.set(color="#222222", linewidth=1.0)
    for cap in box["caps"]:
        cap.set(color="#222222", linewidth=1.0)

    # Scatter overlay aligned with the boxplots; the generator is re-seeded per
    # figure so the jitter is reproducible.
    rng_local = np.random.default_rng(config.random_seed)
    for idx, vals in enumerate(data_by_model):
        if len(vals) == 0:
            continue
        jitter = rng_local.normal(0, 0.03, size=len(vals))
        x_vals = np.full(len(vals), positions[idx]) + jitter
        ax.scatter(
            x_vals,
            vals,
            s=10,
            alpha=0.35,
            color="#2b2b2b",
            linewidth=0,
        )

    ax.set_title(f"Test {title_label} by Model")
    ax.set_xlabel("Model")
    ax.set_ylabel(title_label)
    # boxplot() moved the ticks to its own positions; restore one tick per model.
    ax.set_xticks(np.arange(len(order)))
    ax.set_xticklabels(order, rotation=45, ha="right")
    ax.set_xlim(-0.6, len(order) - 0.4)
    sns.despine(ax=ax, left=True, bottom=True)
    fig.tight_layout()
    return fig


def _plot_metric_mean_heatmap(mean_table: pd.DataFrame):
    """Draw one annotated single-column heatmap per metric, side by side.

    Each column gets its own colour limits from ``compute_heatmap_limits`` so
    metrics on different scales remain readable. Returns the created figure.
    """
    fig_height = max(4, 0.4 * len(mean_table.index))
    metric_cols = list(mean_table.columns)
    fig, axes = plt.subplots(
        nrows=1,
        ncols=len(metric_cols),
        figsize=(1.6 * len(metric_cols), fig_height),
        squeeze=False,
    )
    for idx, metric in enumerate(metric_cols):
        ax = axes[0][idx]
        col_data = mean_table[[metric]]
        vmin_col, vmax_col = compute_heatmap_limits(col_data.values, lower_percentile=10.0, upper_percentile=95.0)
        sns.heatmap(
            col_data,
            cmap="viridis",
            vmin=vmin_col,
            vmax=vmax_col,
            annot=True,
            fmt=".2f",
            linewidths=0.2,
            linecolor="#f5f5f5",
            cbar=False,
            ax=ax,
        )
        ax.set_title("")
        ax.set_xlabel('')
        ax.set_ylabel("Model" if idx == 0 else "")
        if idx == 0:
            ax.yaxis.labelpad = 20
        ax.set_yticklabels(ax.get_yticklabels(), fontsize=8)
        # Only the first panel keeps the model names.
        if idx != 0:
            ax.set_yticklabels([])
    fig.suptitle("Summary metrics (test means)", fontsize=12)
    sns.despine(fig, left=True, bottom=True)
    fig.tight_layout()
    return fig


metric_df = analysis_state.get("metrics_long")
if metric_df is None or metric_df.empty:
    print("Metrics unavailable; skipping metric boxplots and heatmaps.")
else:
    test_metric_df = metric_df[metric_df["split"] == config.primary_split].copy()
    if "model_display" not in test_metric_df.columns and "model_display_name" in test_metric_df.columns:
        test_metric_df["model_display"] = test_metric_df["model_display_name"]

    metrics_to_plot = {
        "pearson": "Pearson",
        "spearman": "Spearman",
        "r2": "R2",
        "rmse": "RMSE",
    }

    # Model order is the same for every metric: best mean test Pearson first,
    # falling back to alphabetical when the summary offers no ranking.
    order: list[str] = []
    if not test_metric_df.empty and "model_display" in test_metric_df.columns:
        pearson_order = (
            summary_reset.groupby("model_display")["test_pearson_mean"].mean().sort_values(ascending=False).index
            if "test_pearson_mean" in summary_reset
            else []
        )
        observed_models = set(test_metric_df["model_display"].dropna().unique())
        order = [name for name in pearson_order if name in observed_models]
        if not order:
            order = sorted(observed_models)

    if not order:
        # Nothing to draw: plotting an empty split would fail inside
        # matplotlib rather than produce a useful figure.
        print(f"No labelled metrics for split '{config.primary_split}'; skipping metric boxplots and heatmaps.")
        available_metrics = {}
    else:
        available_metrics = {
            metric_key: title_label
            for metric_key, title_label in metrics_to_plot.items()
            if metric_key in test_metric_df.columns
        }

    for metric_key, title_label in available_metrics.items():
        fig = _plot_test_metric_distribution(test_metric_df, metric_key, title_label, order)
        register_figure(FIGURES, f"test_{metric_key}_boxplot", fig)
        display(fig)
        plt.close(fig)

    # Combined mean-metric heatmap (Pearson/Spearman/R2/RMSE)
    metric_means = {}
    for metric_key, title_label in available_metrics.items():
        series = test_metric_df.groupby("model_display")[metric_key].mean().dropna()
        if not series.empty:
            metric_means[title_label] = series

    if metric_means:
        combined_df = pd.DataFrame(metric_means)
        # Rows follow the summary table's model order; models absent from the
        # summary are left out.
        model_order = summary_reset[["model_display", "model_id"]].drop_duplicates("model_id")["model_display"].tolist()
        ordered_index = [name for name in model_order if name in combined_df.index]
        combined_df = combined_df.reindex(index=ordered_index)
        if combined_df.empty:
            print("No summary models overlap the test metrics; skipping combined metric heatmap.")
        else:
            fig = _plot_metric_mean_heatmap(combined_df)
            register_figure(FIGURES, "test_metric_mean_heatmap_combined", fig)
            display(fig)
            plt.close(fig)


## 11. Generalization Gap Overview

Train vs test Pearson summary for each model.


In [None]:
# Horizontal bar chart of (train - test) mean Pearson for the best run of each
# model, largest gap first.
summary_reset = analysis_state["summary_reset"]
required_cols = {"train_pearson_mean", "test_pearson_mean"}
fig_generalization_gap = None
gap_df = pd.DataFrame()
if not required_cols.issubset(summary_reset.columns):
    print("Train/test summary columns unavailable; skipping generalization gap plot.")
else:
    # Keep one row per model: the run with the highest mean test Pearson.
    ranked = summary_reset.sort_values("test_pearson_mean", ascending=False)
    best_per_model = ranked.drop_duplicates("model_id")
    gap_df = best_per_model[["model_display", "train_pearson_mean", "test_pearson_mean"]].copy()
    gap_df["generalization_gap"] = gap_df["train_pearson_mean"] - gap_df["test_pearson_mean"]
    # A model missing either mean has an undefined gap; dropping it avoids NaN
    # axis limits and "nan" bar labels.
    gap_df = gap_df.dropna(subset=["generalization_gap"]).sort_values("generalization_gap", ascending=False)
    if gap_df.empty:
        print("No models with both train and test Pearson means; skipping generalization gap plot.")

if not gap_df.empty:
    fig_height = max(4, 0.35 * len(gap_df))
    fig_generalization_gap, ax = plt.subplots(figsize=(8, fig_height))
    colors = sns.color_palette("mako", n_colors=len(gap_df))
    # Numeric positions give every row its own bar even if two models share a
    # display name (a categorical axis would stack them on one tick).
    bar_positions = np.arange(len(gap_df))
    bars = ax.barh(
        bar_positions,
        gap_df["generalization_gap"],
        color=colors,
        linewidth=0,
    )
    ax.set_yticks(bar_positions)
    ax.set_yticklabels(gap_df["model_display"])
    ax.axvline(0.0, color="#6d6d6d", linestyle="--", linewidth=1)
    gap_min = gap_df["generalization_gap"].min()
    gap_max = gap_df["generalization_gap"].max()
    span = max(0.01, gap_max - gap_min)
    margin = max(0.08, span * 0.12)
    # Always show at least [-0.15, 0.15] so small gaps stay comparable between runs.
    xmin = min(gap_min - margin, -0.15)
    xmax = max(gap_max + margin, 0.15)
    ax.set_xlim(xmin, xmax)
    ax.set_xlabel("Train minus Test Mean Pearson")
    ax.set_ylabel("Model")
    ax.set_title("Generalization Gap (Train - Test)")
    label_offset = (xmax - xmin) * 0.02
    for bar, value in zip(bars, gap_df["generalization_gap"]):
        y = bar.get_y() + bar.get_height() / 2
        # Put the label just past the bar's end, on whichever side it extends.
        if value >= 0:
            ax.text(value + label_offset, y, f"{value:.3f}", va="center", ha="left", fontsize=9, color="#1f1f1f")
        else:
            ax.text(value - label_offset, y, f"{value:.3f}", va="center", ha="right", fontsize=9, color="#1f1f1f")
    sns.despine(fig_generalization_gap, left=True, bottom=True)
    fig_generalization_gap.tight_layout()

# Registered even when None, exactly as before.
register_figure(FIGURES, "generalization_gap", fig_generalization_gap)
if fig_generalization_gap is not None:
    display(fig_generalization_gap)
    plt.close(fig_generalization_gap)


## 12. Top-Model Diagnostics

Quick previews for the best-performing models.


In [None]:
# Show the pre-rendered `scatter_test.png` for each of the top models, and
# report which models have no preview available.
summary_reset = analysis_state["summary_reset"]
run_df = analysis_state["run_df"]

best_models = summary_reset.head(config.top_model_count).reset_index(drop=True)
if best_models.empty:
    print("No models available for scatter plot previews.")
else:
    missing_assets: list[str] = []
    for model_row in best_models.itertuples(index=False):
        label = f"{model_row.model_display} (run {model_row.run_name})"
        run_rows = run_df[(run_df["model_id"] == model_row.model_id) & (run_df["run_name"] == model_row.run_name)]
        if run_rows.empty:
            missing_assets.append(f"Missing run directory metadata for {label}.")
            continue
        row = run_rows.iloc[0]

        # A missing path may be stored as None or NaN; Path() would raise on
        # either, so only accept real, non-empty path values.
        model_dir_raw = row.get("model_path")
        if not isinstance(model_dir_raw, (str, Path)) or not str(model_dir_raw):
            missing_assets.append(f"Model directory not recorded for {label}.")
            continue
        model_dir = Path(model_dir_raw)
        if not model_dir.exists():
            missing_assets.append(f"Model directory missing on disk for {label}.")
            continue

        scatter_path = model_dir / "scatter_test.png"
        if scatter_path.exists():
            display(Markdown(
                f"**{model_row.model_display}** (run `{model_row.run_name}`) — `scatter_test.png`"
            ))
            display(Image(filename=str(scatter_path)))
            continue

        # No image: say whether predictions exist (plot simply not rendered)
        # or the run has no test predictions at all.
        predictions_raw = row.get("predictions_path")
        has_predictions = (
            isinstance(predictions_raw, (str, Path))
            and bool(str(predictions_raw))
            and Path(predictions_raw).exists()
        )
        if has_predictions:
            missing_assets.append(
                f"`scatter_test.png` not found for {label}. "
                f"Checked `{to_relative_path(model_dir, config.project_root)}`."
            )
        else:
            missing_assets.append(
                f"No test predictions available for {label}; skipping scatter preview."
            )

    for message in missing_assets:
        print(message)

In [None]:
# Show a training-loss plot for each of the top models, and report which
# models have none on disk.
summary_reset = analysis_state["summary_reset"]
run_df = analysis_state["run_df"]

best_models = summary_reset.head(config.top_model_count).reset_index(drop=True)
if best_models.empty:
    print("Best models not identified; skipping training history previews.")
else:
    missing_assets: list[str] = []
    for model_row in best_models.itertuples(index=False):
        label = f"{model_row.model_display} (run {model_row.run_name})"
        run_rows = run_df[(run_df["model_id"] == model_row.model_id) & (run_df["run_name"] == model_row.run_name)]
        if run_rows.empty:
            missing_assets.append(f"Missing run directory metadata for {label}.")
            continue

        # A missing path may be stored as None or NaN; Path() would raise on
        # either, so only accept real, non-empty path values.
        model_dir_raw = run_rows.iloc[0].get("model_path")
        if not isinstance(model_dir_raw, (str, Path)) or not str(model_dir_raw):
            missing_assets.append(f"Model directory not recorded for {label}.")
            continue
        model_dir = Path(model_dir_raw)
        if not model_dir.exists():
            missing_assets.append(f"Model directory missing on disk for {label}.")
            continue

        # Prefer the top-level loss plot, then any `*_loss.png` under
        # `histories/` in sorted order.
        history_candidates = [model_dir / "training_history_loss.png"]
        history_dir = model_dir / "histories"
        if history_dir.exists():
            history_candidates.extend(sorted(history_dir.glob("*_loss.png")))
        history_path = next((path for path in history_candidates if path.exists()), None)
        if history_path is None:
            missing_assets.append(f"No training history plot available for {label}.")
            continue

        display(Markdown(
            f"**{model_row.model_display}** (run `{model_row.run_name}`) — `{history_path.name}`"
        ))
        display(Image(filename=str(history_path)))

    for message in missing_assets:
        print(message)

## 13. Export Artifacts

Persist figures and tables to disk.


In [None]:
# Additional exploratory figures (moved from manuscript figures)
metrics_wide = analysis_state.get("metrics_wide")
summary_reset = analysis_state.get("summary_reset")

# Per-gene heatmap
fig_gene_heatmap = None
if isinstance(metrics_wide, pd.DataFrame) and not metrics_wide.empty and f"{config.primary_split}_pearson" in metrics_wide:
    test_pearson_per_gene = metrics_wide.pivot_table(
        index="gene",
        columns="model_display",
        values=f"{config.primary_split}_pearson",
        aggfunc="mean",
    )
    if not test_pearson_per_gene.empty and len(test_pearson_per_gene) > 3:
        gene_means = test_pearson_per_gene.mean(axis=1).sort_values(ascending=False)
        test_pearson_per_gene = test_pearson_per_gene.loc[gene_means.index]
        col_means = test_pearson_per_gene.mean(axis=0).sort_values(ascending=False)
        test_pearson_per_gene = test_pearson_per_gene[col_means.index]
        fig_height = max(8, 0.15 * len(test_pearson_per_gene))
        fig_width = max(9, 0.4 * len(test_pearson_per_gene.columns))
        fig_gene_heatmap, ax = plt.subplots(figsize=(fig_width, fig_height))
        vmin, vmax = compute_heatmap_limits(test_pearson_per_gene.values)
        sns.heatmap(
            test_pearson_per_gene,
            cmap="RdYlGn",
            vmin=vmin,
            vmax=vmax,
            annot=False,
            cbar_kws={"label": "Test Pearson"},
            ax=ax,
        )
        ax.set_title("Per-Gene Test Pearson by Model (sorted by mean performance)")
        ax.set_xlabel("Model")
        ax.set_ylabel("Gene")
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", fontsize=8)
        plt.setp(ax.get_yticklabels(), fontsize=7)
        fig_gene_heatmap.tight_layout()
        register_figure(FIGURES, "per_gene_heatmap", fig_gene_heatmap)
        display(fig_gene_heatmap)
        plt.close(fig_gene_heatmap)

# Top features (best model)
fig_top_features = None
best_details = analysis_state.get("best_model_details")
if best_details is None:
    print("Best model context missing; skipping top feature summary.")
else:
    importance_df, _ = _load_feature_importance_table(best_details["model_dir"])
    if importance_df.empty:
        print("No feature importance data available for top-features plot.")
    else:
        top_df = importance_df.nlargest(20, "importance").copy()
        top_df["feature_short"] = top_df["feature"].str.split("|").str[-1].str[:24]
        fig_top_features, ax = plt.subplots(figsize=(9, 6))
        ax.barh(range(len(top_df)), top_df["importance"], color="#2c7fb8", alpha=0.85)
        ax.set_yticks(range(len(top_df)))
        ax.set_yticklabels(top_df["feature_short"], fontsize=8)
        ax.set_xlabel("Importance")
        ax.set_title(f"Top Features | Best Model: {best_details['model_display']}")
        ax.invert_yaxis()
        sns.despine(fig_top_features, left=True, bottom=True)
        fig_top_features.tight_layout()
        register_figure(FIGURES, "top_features", fig_top_features)
        display(fig_top_features)
        plt.close(fig_top_features)


In [None]:
# Maps each key in FIGURES to the file(s) it is written to. Keys must match
# the names used when the figures were registered.
FIGURE_SAVE_PLAN: dict[str, list[Path]] = {
    "test_pearson_heatmap_all_models": [config.fig_dir / "test_pearson_heatmap_all_models.png"],
    "test_pearson_violin": [config.fig_dir / "test_pearson_violin.png"],
    "split_comparison_overview": [config.fig_dir / "split_comparison_overview.png"],
    "test_pearson_heatmap_top20_per_model_blocks": [config.fig_dir / "test_pearson_heatmap_top20_per_model_blocks.png"],
    "test_pearson_boxplot": [config.fig_dir / "test_pearson_boxplot.png"],
    "test_spearman_boxplot": [config.fig_dir / "test_spearman_boxplot.png"],
    "test_r2_boxplot": [config.fig_dir / "test_r2_boxplot.png"],
    "test_rmse_boxplot": [config.fig_dir / "test_rmse_boxplot.png"],
    "test_metric_mean_heatmap_combined": [config.fig_dir / "test_metric_mean_heatmap_combined.png"],
    "generalization_gap": [config.fig_dir / "generalization_gap.png"],
    "best_model_feature_importance_top": [config.fig_dir / "best_model_feature_importance_top.png"],
    # The dataset comparison cell stores its figures under these two keys; the
    # output filenames are kept as before.
    "train_val_pearson_by_dataset_box": [config.fig_dir / "train_val_box_by_dataset.png"],
    "test_pearson_by_dataset_violin": [config.fig_dir / "test_violin_by_dataset.png"],
    "resource_usage_summary": [config.fig_dir / "resource_usage_summary.png"],
    "per_gene_heatmap": [config.fig_dir / "per_gene_heatmap.png"],
    "top_features": [config.fig_dir / "top_features.png"],
}

# Per-model top-gene figures have dynamic keys, published by the cell that
# creates them.
top_gene_keys = analysis_state.get("top_gene_figure_keys", [])
for key in top_gene_keys:
    filename = f"{key}.png"
    FIGURE_SAVE_PLAN[key] = [config.fig_dir / filename]

# Maps each key in TABLES to the CSV file(s) it is written to.
TABLE_SAVE_PLAN: dict[str, list[Path]] = {
    "metrics_per_gene_master": [config.reports_dir / "metrics_per_gene_master.csv"],
    "summary_metrics_all_models": [config.reports_dir / "summary_metrics_all_models.csv"],
}

saved_figures: list[Path] = []
for key, targets in FIGURE_SAVE_PLAN.items():
    fig = FIGURES.get(key)
    if fig is None:
        print(f"Skipping figure '{key}' (not generated).")
        continue
    for target in targets:
        target.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(target, bbox_inches="tight")
        saved_figures.append(target)

saved_tables: list[Path] = []
for key, targets in TABLE_SAVE_PLAN.items():
    table = TABLES.get(key)
    if table is None or table.empty:
        print(f"Skipping table '{key}' (not generated or empty).")
        continue
    for target in targets:
        target.parent.mkdir(parents=True, exist_ok=True)
        table.to_csv(target, index=False)
        saved_tables.append(target)

print("Saved figures:")
for target in saved_figures:
    print(f" - {target}")

print("Saved tables:")
for target in saved_tables:
    print(f" - {target}")

# Record what was written (relative to the project root) for the metadata report.
if saved_figures:
    analysis_metadata["saved_figures"] = [to_relative_path(path, config.project_root) for path in saved_figures]


## 14. Session Metadata

Run summary details for logging and provenance.


In [None]:
# Read the provenance fields recorded by earlier cells, using "n/a" for any
# that are absent, and render them as a short Markdown report.
(
    generated_at,
    results_root,
    run_count,
    model_count,
    best_display,
    best_id,
    best_run,
) = (
    analysis_metadata.get(field, "n/a")
    for field in (
        "generated_at",
        "results_root",
        "run_count",
        "model_count",
        "best_model_display",
        "best_model_id",
        "best_run_name",
    )
)
report_markdown = "\n".join(
    [
        "### Analysis Metadata",
        f"- Generated at: `{generated_at}`",
        f"- Results root: `{results_root}`",
        f"- Runs analysed: `{run_count}`",
        f"- Models analysed: `{model_count}`",
        f"- Best model: `{best_display}` (`{best_id}`)",
        f"- Source run: `{best_run}`",
    ]
)
display(Markdown(report_markdown))

## 15. Feature Importance (Best Model)

Load feature-importance exports for the best-performing model.


In [None]:
def _optional_resolved_path(value: object) -> Optional[Path]:
    """Return `value` as a resolved Path, or None unless it is a Path or non-empty string."""
    if isinstance(value, Path):
        return value.resolve()
    if isinstance(value, str) and value:
        return Path(value).resolve()
    return None


# Reuse the best-model context when an earlier cell stored it; otherwise
# rebuild it from the first row of the summary table and the matching run
# metadata row.
best_details = analysis_state.get("best_model_details")
if best_details is None:
    summary_reset = analysis_state.get("summary_reset")
    run_df = analysis_state.get("run_df")
    if summary_reset is None or summary_reset.empty:
        raise RuntimeError("Best-model summary unavailable; please rerun Section 5 before using Section 15.")
    if run_df is None or run_df.empty:
        raise RuntimeError("Run metadata missing; rerun Section 4 before using Section 15.")
    # NOTE(review): assumes summary_reset is already ordered best-first — confirm upstream.
    leader = summary_reset.iloc[0]
    best_model_id = str(leader["model_id"])
    best_run_name = str(leader["run_name"])
    best_model_display = str(leader.get("model_display", best_model_id))
    match = run_df[(run_df["model_id"] == best_model_id) & (run_df["run_name"] == best_run_name)]
    if match.empty:
        raise RuntimeError(f"Unable to locate run folder for {best_model_display} (run {best_run_name}).")
    row = match.iloc[0]

    # The model directory is mandatory; the remaining paths are optional.
    best_model_dir = _optional_resolved_path(row.get("model_path"))
    if best_model_dir is None:
        raise RuntimeError("Model directory not recorded for the leading model.")
    best_run_dir = _optional_resolved_path(row.get("run_path"))
    if best_run_dir is None:
        # Fall back to the model directory's parent when no run path is recorded.
        best_run_dir = best_model_dir.parent
    metrics_path = _optional_resolved_path(row.get("metrics_path"))
    predictions_path = _optional_resolved_path(row.get("predictions_path"))

    best_details = {
        "model_id": best_model_id,
        "model_display": best_model_display,
        "run_name": best_run_name,
        "model_dir": best_model_dir,
        "run_dir": best_run_dir,
        "metrics_path": metrics_path,
        "predictions_path": predictions_path,
    }
    analysis_state["best_model_details"] = best_details

# Render a short Markdown summary of where the best model's artifacts live.
relative_model_dir = to_relative_path(best_details["model_dir"], config.project_root)
summary_lines = [
    f"**Best model context**: `{best_details['model_display']}` (`{best_details['model_id']}`) from run `{best_details['run_name']}`",
    "",
    f"- Model directory: `{relative_model_dir}`",
]
if best_details.get("metrics_path") is not None:
    summary_lines.append(
        f"- Metrics file: `{to_relative_path(best_details['metrics_path'], config.project_root)}`"
    )
if best_details.get("predictions_path") is not None:
    summary_lines.append(
        f"- Predictions file: `{to_relative_path(best_details['predictions_path'], config.project_root)}`"
    )
display(Markdown("\n".join(summary_lines)))

In [None]:
def _load_feature_importance_table(model_dir: Path) -> tuple[pd.DataFrame, Optional[Path]]:
    """Locate and standardise feature importances exported by the training pipeline."""
    patterns = [
        "feature_importance*.csv",
        "feature_importances*.csv",
        "feature_importance*.tsv",
        "feature_importances*.tsv",
        "feature_importance*.parquet",
        "feature_importances*.parquet",
    ]
    candidates: list[Path] = []
    for pattern in patterns:
        candidates.extend(model_dir.glob(pattern))
    if not candidates:
        for pattern in patterns:
            candidates.extend(model_dir.glob(f"**/{pattern}"))
    unique_candidates: list[Path] = []
    seen: set[Path] = set()
    for path in candidates:
        resolved = path.resolve()
        if resolved in seen:
            continue
        if resolved.suffix.lower() in {".png", ".jpg", ".jpeg"}:
            continue
        seen.add(resolved)
        unique_candidates.append(resolved)
    for candidate in sorted(unique_candidates):
        try:
            if candidate.suffix.lower() == ".parquet":
                df = pd.read_parquet(candidate)
            else:
                sep = "\t" if candidate.suffix.lower() in {".tsv", ".txt"} else ","
                df = pd.read_csv(candidate, sep=sep)
        except Exception as exc:
            print(f"Skipping {candidate.name}: failed to load ({exc})")
            continue
        if df.empty:
            continue
        lower_cols = {col.lower(): col for col in df.columns}
        feature_col = next((
            lower_cols[key]
            for key in ("feature", "feature_name", "name", "variable", "feature_id", "column")
            if key in lower_cols
        ), None)
        importance_col = next((
            lower_cols[key]
            for key in (
                "importance",
                "importance_score",
                "importance_mean",
                "score",
                "value",
                "gain",
                "weight",
            )
            if key in lower_cols
        ), None)
        if feature_col is None or importance_col is None:
            continue
        out = df.copy()
        out.rename(columns={feature_col: "feature", importance_col: "importance"}, inplace=True)
        out["feature"] = out["feature"].astype(str)
        out["importance"] = pd.to_numeric(out["importance"], errors="coerce")
        out = out.dropna(subset=["feature", "importance"])
        if out.empty:
            continue
        extra_cols = [col for col in out.columns if col not in {"feature", "importance"}]
        out = out[["feature", "importance", *extra_cols]]
        out.sort_values("importance", ascending=False, inplace=True)
        return out.reset_index(drop=True), candidate
    return pd.DataFrame(columns=["feature", "importance"]), None

# Load the best model's feature-importance export and cache it for the cells below.
importance_df, importance_path = _load_feature_importance_table(best_details["model_dir"])
analysis_state.update(
    {
        "best_model_feature_importances": importance_df,
        "best_model_feature_importances_path": importance_path,
    }
)
if not importance_df.empty:
    note_lines = ["**Feature importances**: detected data source."]
    if importance_path is not None:
        # Source and row count are reported together, only when the origin file is known.
        note_lines += [
            f"- Source file: `{to_relative_path(importance_path, config.project_root)}`",
            f"- Rows: {len(importance_df):,}",
        ]
    display(Markdown("\n".join(note_lines)))
    display(importance_df.head(10))
else:
    display(
        Markdown(
            "**Feature importances**: no compatible export found in the best-model directory. "
            "Upload or regenerate a table named `feature_importance*.csv|tsv|parquet` to enable the following cells."
        )
    )

In [None]:
# Histogram (with KDE) of the finite importance scores for the best model.
importance_df = analysis_state.get("best_model_feature_importances")
fig_feature_importance_dist = None
if importance_df is None or importance_df.empty:
    print("Feature importance data unavailable; skipping distribution plot.")
else:
    importance_values = importance_df["importance"].astype(float)
    # Drop NaN and +/-inf so the histogram bins stay well defined.
    importance_values = importance_values[np.isfinite(importance_values)]
    if not importance_values.empty:
        fig_feature_importance_dist, ax = plt.subplots(figsize=(6.5, 4.2))
        sns.histplot(importance_values, bins=40, kde=True, color="#377eb8", ax=ax)
        ax.set(
            title="Feature Importance Distribution | Best Model",
            xlabel="Importance score",
            ylabel="Feature count",
        )
        sns.despine(fig_feature_importance_dist, left=True, bottom=True)
        plt.tight_layout()
        display(fig_feature_importance_dist)
        plt.close(fig_feature_importance_dist)
    else:
        print("Importance column contains no finite values; skipping distribution plot.")

### Top Feature Importance (Best Model)


In [None]:
# Horizontal bar chart of the (up to) 20 highest-importance features.
importance_df = analysis_state.get("best_model_feature_importances")
best_details = analysis_state.get("best_model_details", {})
fig_feature_importance_top = None
if importance_df is not None and not importance_df.empty:
    top_n = min(20, len(importance_df))
    top_df = importance_df.nlargest(top_n, "importance").copy()
    top_df["feature"] = top_df["feature"].astype(str)
    # Grow the figure with the number of bars so labels stay legible.
    fig_height = max(4.0, 0.35 * top_n + 1.0)
    fig_feature_importance_top, ax = plt.subplots(figsize=(7.5, fig_height))
    sns.barplot(data=top_df, x="importance", y="feature", color="#4C78A8", ax=ax)
    model_label = best_details.get("model_display", "Best Model")
    ax.set(
        title=f"Top Feature Importance | {model_label}",
        xlabel="Importance score",
        ylabel="Feature",
    )
    sns.despine(fig_feature_importance_top, left=True, bottom=True)
    plt.tight_layout()
else:
    print("Feature importance data unavailable; skipping top-feature plot.")
# The figure is registered on both paths (with None when the plot was skipped).
register_figure(FIGURES, "best_model_feature_importance_top", fig_feature_importance_top)
if fig_feature_importance_top is not None:
    display(fig_feature_importance_top)
    plt.close(fig_feature_importance_top)


In [None]:
# Imported locally so this cell does not rely on an earlier cell having imported `re`.
import re

# Scatter of feature importance against distance to the TSS for the best model.
# The distance column is resolved in three steps, first match wins:
#   1. an explicit base-pair column already present in the table,
#   2. the midpoint of a `bin_<start>_to_<end>` token in the feature name,
#   3. a `distance_to_tss_kb` column, converted to base pairs.
importance_df = analysis_state.get("best_model_feature_importances")
distance_col = None
if importance_df is None or importance_df.empty:
    print("Feature importance data unavailable; skipping TSS distance analysis.")
else:
    candidate_cols = [
        "distance_to_tss_bp",
        "distance_to_tss",
        "distance_bp",
        "tss_distance_bp",
    ]
    for candidate in candidate_cols:
        if candidate in importance_df.columns:
            distance_col = candidate
            break
    if distance_col is None:
        bin_pattern = re.compile(r"bin_(-?\d+)_to_(-?\d+)")

        def _infer_distance(feature_name: str) -> Optional[float]:
            """Return the midpoint of the ``bin_<start>_to_<end>`` token, or None if absent."""
            if not isinstance(feature_name, str):
                return None
            # Only the part after the first "|" is searched (the whole name if there is none).
            token = feature_name.split("|", 1)[-1]
            match = bin_pattern.search(token)
            if match:
                start_bp = float(match.group(1))
                end_bp = float(match.group(2))
                return 0.5 * (start_bp + end_bp)
            return None

        inferred = importance_df["feature"].map(_infer_distance)
        if inferred.notna().any():
            # Work on a copy so the cached table is replaced rather than mutated in place.
            importance_df = importance_df.copy()
            importance_df["distance_to_tss_bp"] = inferred
            distance_col = "distance_to_tss_bp"
            analysis_state["best_model_feature_importances"] = importance_df
    if distance_col is None and "distance_to_tss_kb" in importance_df.columns:
        # Last resort: the appendix cells read `distance_to_tss_kb` from the exported
        # table and treat it as kilobases, so convert it to the bp column used here.
        distance_bp = pd.to_numeric(importance_df["distance_to_tss_kb"], errors="coerce") * 1_000.0
        if distance_bp.notna().any():
            importance_df = importance_df.copy()
            importance_df["distance_to_tss_bp"] = distance_bp
            distance_col = "distance_to_tss_bp"
            analysis_state["best_model_feature_importances"] = importance_df
    if distance_col is None:
        print("Unable to derive feature-to-TSS distances from the available data; skipping scatter plot.")
    else:
        # Remember which column holds the distances for the window-filter cell.
        analysis_state["feature_importance_distance_column"] = distance_col
        valid = importance_df.dropna(subset=["importance", distance_col]).copy()
        if valid.empty:
            print("No features contained both importance scores and TSS distances.")
        else:
            valid["distance_kb"] = valid[distance_col].astype(float) / 1_000.0
            corr_value = valid[["distance_kb", "importance"]].corr(method="spearman").loc["distance_kb", "importance"]
            fig_distance, ax = plt.subplots(figsize=(7.5, 4.8))
            sns.scatterplot(
                data=valid,
                x="distance_kb",
                y="importance",
                s=32,
                alpha=0.6,
                edgecolor="none",
                color="#4daf4a",
                ax=ax,
            )
            # LOWESS trend line drawn over the scatter (points themselves are not redrawn).
            sns.regplot(
                data=valid,
                x="distance_kb",
                y="importance",
                scatter=False,
                lowess=True,
                color="#984ea3",
                ax=ax,
            )
            ax.axvline(0.0, color="#999999", linestyle="--", linewidth=1)
            ax.set_xlabel("Distance to TSS (kb)")
            ax.set_ylabel("Feature importance")
            ax.set_title("Feature Importance vs. Distance to TSS | Best Model")
            ax.text(
                0.01,
                0.98,
                f"Spearman r = {corr_value:.3f}",
                transform=ax.transAxes,
                ha="left",
                va="top",
                fontsize=10,
                bbox=dict(boxstyle="round,pad=0.25", facecolor="white", alpha=0.7, edgecolor="#cccccc"),
            )
            sns.despine(fig_distance, left=True, bottom=True)
            plt.tight_layout()
            display(fig_distance)
            plt.close(fig_distance)

In [None]:
# Bar chart of the (up to) 25 most important features that look ATAC/peak related.
importance_df = analysis_state.get("best_model_feature_importances")
if importance_df is None or importance_df.empty:
    print("Feature importance data unavailable; skipping ATAC peak view.")
else:
    feature_series = importance_df["feature"].astype(str)
    # Prefer the explicit feature_type column; otherwise match "peak" in the feature name.
    has_type_column = "feature_type" in importance_df.columns
    label_source = importance_df["feature_type"].astype(str) if has_type_column else feature_series
    keyword_regex = "atac|peak" if has_type_column else "peak"
    atac_mask = label_source.str.contains(keyword_regex, case=False, na=False)
    atac_df = importance_df[atac_mask].copy()
    if not atac_df.empty:
        top_atac = atac_df.nlargest(min(25, len(atac_df)), "importance")
        fig_atac, ax = plt.subplots(figsize=(9, max(4, 0.35 * len(top_atac))))
        sns.barplot(data=top_atac, x="importance", y="feature", palette="crest", ax=ax)
        ax.set(
            title="Top ATAC Feature Importances | Best Model",
            xlabel="Importance score",
            ylabel="Feature",
        )
        plt.tight_layout()
        display(fig_atac)
        plt.close(fig_atac)
    else:
        print("No ATAC-related feature names detected; adjust filtering logic if alternative naming is used.")

In [None]:
# Cache the highest-importance features (at most TOP_FEATURE_COUNT) for later cells.
TOP_FEATURE_COUNT = 1_000
importance_df = analysis_state.get("best_model_feature_importances")
if importance_df is not None and not importance_df.empty:
    top_n = min(TOP_FEATURE_COUNT, len(importance_df))
    top_features_df = importance_df.nlargest(top_n, "importance").reset_index(drop=True)
    analysis_state["best_model_top_features"] = top_features_df
    print(f"Captured top {top_n} features by importance for the best model.")
    display(top_features_df.head(10))
else:
    print("Feature importance data unavailable; unable to derive top features.")

In [None]:
# Keep only the features whose recorded TSS distance lies inside a symmetric window.
importance_df = analysis_state.get("best_model_feature_importances")
distance_col = analysis_state.get("feature_importance_distance_column")
distance_window_kb = 50
if importance_df is None or importance_df.empty:
    print("Feature importance data unavailable; cannot filter by TSS proximity.")
elif distance_col is None or distance_col not in importance_df.columns:
    print("No inferred distance-to-TSS column available; run the previous distance cell first or provide the column explicitly.")
else:
    # The window is given in kb and compared in bp against the distance column.
    window_mask = importance_df[distance_col].abs() <= distance_window_kb * 1_000
    subset_df = importance_df[window_mask].copy()
    if not subset_df.empty:
        subset_df = subset_df.sort_values("importance", ascending=False)
        analysis_state["best_model_features_within_window"] = subset_df
        print(
            f"Identified {len(subset_df):,} features within ±{distance_window_kb} kb of the TSS (column '{distance_col}')."
        )
        display(subset_df.head(10))
    else:
        print(
            f"No features fall within ±{distance_window_kb} kb of the TSS according to column '{distance_col}'."
        )

In [None]:
# Summarise, as Markdown, what a manual rerun of the best model would need.
best_details = analysis_state.get("best_model_details")
top_features_df = analysis_state.get("best_model_top_features")
subset_df = analysis_state.get("best_model_features_within_window")
importance_path = analysis_state.get("best_model_feature_importances_path")
distance_col = analysis_state.get("feature_importance_distance_column")
distance_window_kb = 50
if best_details is not None:
    run_config_path = best_details["run_dir"] / "run_configuration.json"
    lines = [
        "**Execution scaffolding for reruns**",
        "",
        f"- Baseline config JSON: `{to_relative_path(run_config_path, config.project_root)}`",
        f"- Suggested CLI template: `python -m spear.cli --config-json {to_relative_path(run_config_path, config.project_root)} --models {best_details['model_id']} --run-name {best_details['run_name']}_experimental`",
    ]
    if importance_path is not None:
        lines.append(
            f"- Feature importance table: `{to_relative_path(importance_path, config.project_root)}`"
        )

    # Top-feature shortlist: present only if the shortlist cell cached a non-empty table.
    has_shortlist = top_features_df is not None and not top_features_df.empty
    lines.append(
        "- Export `analysis_state[\"best_model_top_features\"]` to CSV/TSV and feed it into your data loader to mimic a top-1k feature run."
        if has_shortlist
        else "- Top-1k feature shortlist unavailable; populate feature importances first."
    )

    # Window subset: needs both a cached non-empty subset and a known distance column.
    has_window_subset = subset_df is not None and not subset_df.empty and distance_col
    if has_window_subset:
        lines.extend(
            [
                f"- Features within ±{distance_window_kb} kb (`{distance_col}`) cached in `analysis_state['best_model_features_within_window']`.",
                f"  Update `TrainingConfig.window_bp` to {distance_window_kb * 1_000:,} (±{distance_window_kb} kb) in the JSON before re-running.",
            ]
        )
    else:
        lines.append(
            f"- No cached subset for the ±{distance_window_kb} kb window; rerun the distance cell after providing TSS metadata."
        )

    lines.append(
        "- Reminder: the notebook does not launch training jobs automatically; please run the CLI in a new terminal or submit via SLURM."
    )
    display(Markdown("\n".join(lines)))
else:
    print("Best model context missing; rerun the first cell in Section 13.")

## 16. Resource Usage Comparison

Summarize SLURM log resource usage across runs.


In [None]:
import re
import warnings

# Imported here so the legend code below does not depend on an earlier cell.
from matplotlib.patches import Patch

warnings.filterwarnings('ignore')

# Peak-memory summary parsed from the Slurm stdout logs of the cell-wise runs.
# Each log contributes one row: the last `run_name=` value seen in the file and the
# largest "Resource snapshot | <stage> | rss=<x> GiB" value with its stage.
log_dir = config.project_root / "output" / "logs"
log_paths = sorted(log_dir.glob("spear_cellwise_chunk*.out"))
# Bound on every path so the name is always defined once this cell has run.
fig_resource = None
if not log_paths:
    print("No Slurm logs found; skipping resource usage summary.")
else:
    if "analysis_state" not in globals():
        analysis_state = {}
    run_df = analysis_state.get("run_df")
    display_lookup: dict[str, str] = {}
    model_id_lookup: dict[str, str] = {}
    if isinstance(run_df, pd.DataFrame) and not run_df.empty:
        if "model_display" in run_df.columns and "run_name" in run_df.columns:
            display_lookup = dict(zip(run_df["run_name"], run_df["model_display"]))
        if "model_id" in run_df.columns and "run_name" in run_df.columns:
            model_id_lookup = dict(zip(run_df["run_name"], run_df["model_id"]))

    # Patterns are compiled once, before the per-line loop.
    run_name_token = re.compile(r"[\w\-]+")
    rss_pattern = re.compile(r"Resource snapshot\s*\|\s*([^|]+)\|\s*rss=([0-9.]+)\s*GiB", re.IGNORECASE)
    # Substrings of the run name that mark a run as CPU-only; anything else is labelled GPU.
    cpu_markers = ("cpu", "ridge", "xgboost")
    rows: list[dict[str, object]] = []
    for path in log_paths:
        try:
            with path.open("r", encoding="utf-8", errors="ignore") as fh:
                lines = fh.readlines()
        except Exception as exc:
            # Best effort: report the unreadable log and carry on with the rest.
            print(f"Skipping {path.name}: failed to read ({exc}).")
            continue

        run_name = None
        max_rss = None
        peak_stage = None
        for line in lines:
            if "run_name=" in line:
                # First run of word/hyphen characters after the first "run_name=".
                rest = line.partition("run_name=")[2]
                candidate_match = run_name_token.search(rest)
                if candidate_match:
                    run_name = candidate_match.group()
            match = rss_pattern.search(line)
            if match:
                stage_str = match.group(1).strip()
                rss_val = float(match.group(2))
                if max_rss is None or rss_val > max_rss:
                    max_rss = rss_val
                    peak_stage = stage_str

        if run_name and max_rss is not None:
            display_name = display_lookup.get(run_name, run_name)
            model_id = model_id_lookup.get(run_name, "")
            run_name_lower = run_name.lower()
            is_cpu_run = any(marker in run_name_lower for marker in cpu_markers)
            device_label = "CPU" if is_cpu_run else "GPU"
            label = f"{display_name}\n{run_name}"
            rows.append({
                "run_label": run_name,
                "model_id": model_id,
                "peak_rss_gib": max_rss,
                "peak_stage": peak_stage,
                "resolved_name": display_name,
                "accelerator": device_label,
                "label": label,
                "log_file": path.name,
            })

    if not rows:
        # Logs exist, but none produced both a run name and a resource snapshot.
        print("Resource snapshots not yet present in the logs; rerun after the jobs emit resource metrics.")
    else:
        resource_df = pd.DataFrame(rows)
        sorted_df = resource_df.sort_values("peak_rss_gib", ascending=False)
        analysis_state["resource_usage"] = sorted_df
        table = sorted_df.rename(
            columns={
                "resolved_name": "Model label",
                "run_label": "Run name",
                "model_id": "Model ID",
                "accelerator": "Accelerator",
                "peak_rss_gib": "Peak RSS (GiB)",
                "peak_stage": "Peak stage",
                "log_file": "Log file",
            }
        )
        display(
            table[
                [
                    "Model label",
                    "Run name",
                    "Model ID",
                    "Accelerator",
                    "Peak RSS (GiB)",
                    "Peak stage",
                    "Log file",
                ]
            ]
        )
        valid = sorted_df.dropna(subset=["peak_rss_gib"]).copy()
        if not valid.empty:
            color_map = {"CPU": "#4C72B0", "GPU": "#DD8452"}
            colors = valid["accelerator"].map(color_map).fillna("#808080")
            fig_height = max(3.8, 0.4 * len(valid))
            fig_resource, ax = plt.subplots(figsize=(9, fig_height))
            bars = ax.barh(valid["label"], valid["peak_rss_gib"], color=colors)
            ax.invert_yaxis()
            ax.set_xlabel("Peak RSS (GiB)")
            ax.set_ylabel("Model run")
            ax.set_title("Peak Memory Usage Across SPEAR Cell-wise Runs")
            xmax = valid["peak_rss_gib"].max()
            if xmax <= 0:
                xmax = 1.0
            # Leave headroom to the right of the longest bar for the value labels.
            ax.set_xlim(0, xmax * 1.15)
            offset = max(1.0, xmax * 0.03)
            for bar, (_, row) in zip(bars, valid.iterrows()):
                stage_note = ""
                if isinstance(row["peak_stage"], str) and row["peak_stage"]:
                    stage_note = f" @ {row['peak_stage']}"
                ax.text(
                    bar.get_width() + offset,
                    bar.get_y() + bar.get_height() / 2,
                    f"{row['peak_rss_gib']:.1f} GiB{stage_note}",
                    va="center",
                    fontsize=8.5,
                    color="#1f1f1f",
                )
            legend_handles = [
                Patch(color=color_map[key], label=key)
                for key in sorted(valid["accelerator"].dropna().unique())
                if key in color_map
            ]
            if legend_handles:
                ax.legend(handles=legend_handles, title="Accelerator", loc="lower right")
            sns.despine(ax=ax, left=True, bottom=True)
            fig_resource.tight_layout()
            register_figure(FIGURES, "resource_usage_summary", fig_resource)
            display(fig_resource)
            plt.close(fig_resource)

In [None]:
import re
from datetime import datetime

# Summarize per-model runtime + peak resource usage from per-run logs.
analysis_state = globals().get("analysis_state", {})
if not isinstance(analysis_state, dict):
    analysis_state = {}
    globals()["analysis_state"] = analysis_state

log_dir = config.project_root / "output" / "logs"
log_paths = sorted(log_dir.glob("spear_*_compute_*.log"))
if not log_paths:
    print("No model run logs found under output/logs; skipping resource summary.")
else:
    run_df = analysis_state.get("run_df")
    display_lookup: dict[str, str] = {}
    model_lookup: dict[str, str] = {}
    if isinstance(run_df, pd.DataFrame) and not run_df.empty:
        if "model_display" in run_df.columns and "run_name" in run_df.columns:
            display_lookup = dict(zip(run_df["run_name"], run_df["model_display"]))
        if "model_id" in run_df.columns and "run_name" in run_df.columns:
            model_lookup = dict(zip(run_df["run_name"], run_df["model_id"]))

    run_name_pattern = re.compile(r"run_name=([\w\-]+)")
    resource_pattern = re.compile(
        r"Resource snapshot\s*\|\s*(?P<context>[^|]+)\|\s*rss=(?P<rss>[0-9.]+)\s*GiB\s*\|\s*cpu%=(?P<cpu>[0-9.]+)",
        re.IGNORECASE,
    )
    timestamp_pattern = re.compile(r"^(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3})")
    timestamp_format = "%Y-%m-%d %H:%M:%S,%f"

    rows: list[dict[str, object]] = []
    for path in log_paths:
        try:
            with path.open("r", encoding="utf-8", errors="ignore") as fh:
                lines = fh.readlines()
        except Exception as exc:
            print(f"Skipping {path.name}: failed to read ({exc}).")
            continue

        run_name = None
        max_rss = None
        max_cpu = None
        start_ts = None
        end_ts = None
        end_ts_complete = None
        for line in lines:
            # Only the first run_name= occurrence in the file is used.
            if run_name is None and "run_name=" in line:
                match = run_name_pattern.search(line)
                if match:
                    run_name = match.group(1)

            # Match and parse the leading timestamp once per line; the result
            # drives the start, end and completion timestamps below.
            ts_match = timestamp_pattern.match(line)
            line_ts = None
            if ts_match:
                try:
                    line_ts = datetime.strptime(ts_match.group(1), timestamp_format)
                except ValueError:
                    line_ts = None
            if line_ts is not None:
                if start_ts is None:
                    start_ts = line_ts  # first parsable timestamp in the file
                end_ts = line_ts  # last parsable timestamp seen so far
            if ts_match and "RUN_COMPLETE_STATUS" in line:
                # A completion marker takes precedence over the last timestamp;
                # an unparsable one clears it so the fallback applies.
                end_ts_complete = line_ts

            res_match = resource_pattern.search(line)
            if res_match:
                rss_val = float(res_match.group("rss"))
                cpu_val = float(res_match.group("cpu"))
                if max_rss is None or rss_val > max_rss:
                    max_rss = rss_val
                if max_cpu is None or cpu_val > max_cpu:
                    max_cpu = cpu_val

        final_end = end_ts_complete or end_ts
        total_seconds = None
        if start_ts and final_end:
            total_seconds = (final_end - start_ts).total_seconds()

        if run_name:
            rows.append({
                "run_name": run_name,
                "model_id": model_lookup.get(run_name, ""),
                "display_name": display_lookup.get(run_name, run_name),
                "log_file": path.name,
                "peak_rss_gib": max_rss,
                "peak_cpu_pct": max_cpu,
                "start_time": start_ts,
                "end_time": final_end,
                "runtime_seconds": total_seconds,
            })

    if rows:
        resource_df = pd.DataFrame(rows)
        sorted_df = resource_df.sort_values("runtime_seconds", ascending=False, na_position="last")
        analysis_state["resource_runtime"] = sorted_df

        table = sorted_df.rename(
            columns={
                "display_name": "Model display name",
                "run_name": "Run name",
                "model_id": "Model ID",
                "log_file": "Log file",
                "peak_rss_gib": "Peak RSS (GiB)",
                "peak_cpu_pct": "Peak CPU %",
                "runtime_seconds": "Runtime (s)",
            }
        )
        display(
            table[
                [
                    "Model display name",
                    "Run name",
                    "Model ID",
                    "Peak RSS (GiB)",
                    "Peak CPU %",
                    "Runtime (s)",
                    "Log file",
                ]
            ]
        )

        # Only runs with a derivable runtime are plotted.
        valid = sorted_df.dropna(subset=["runtime_seconds"]).copy()
        if not valid.empty:
            valid["runtime_hours"] = valid["runtime_seconds"] / 3600.0
            fig_runtime, ax = plt.subplots(figsize=(9, max(3.8, 0.4 * len(valid))))
            ax.barh(valid["display_name"], valid["runtime_hours"], color="#55A868")
            ax.invert_yaxis()
            ax.set_xlabel("Runtime (hours)")
            ax.set_ylabel("Model run")
            ax.set_title("Per-run Total Runtime (log-sourced)")
            xmax = valid["runtime_hours"].max()
            if xmax <= 0:
                xmax = 1.0
            ax.set_xlim(0, xmax * 1.15)
            offset = max(0.1, xmax * 0.03)
            for _, row in valid.iterrows():
                ax.text(
                    row["runtime_hours"] + offset,
                    row["display_name"],
                    f"{row['runtime_hours']:.2f} h",
                    va="center",
                    fontsize=8.5,
                    color="#1f1f1f",
                )
            sns.despine(ax=ax, left=True, bottom=True)
            fig_runtime.tight_layout()
            register_figure(FIGURES, "runtime_comparison", fig_runtime)
            display(fig_runtime)
            plt.close(fig_runtime)
    else:
        print("No matching run logs found.")

## Appendix: Feature Importance Quick Guide

Extra diagnostics for exploring per-gene feature importance outputs.


### Workflow


In [None]:
from pathlib import Path
import pandas as pd
from IPython.display import display, Image

from spear.visualization import plot_per_gene_feature_panel, plot_cumulative_importance_overlay


In [None]:
# Use the best-performing model from the analysis summary
# Point RUN_DIR at the model directory of the best model recorded earlier in the notebook.
best_details = analysis_state.get("best_model_details")
if best_details is not None:
    RUN_DIR = Path(best_details["model_dir"]).resolve()
else:
    raise RuntimeError("Best-model details missing; rerun Sections 4-5 before this cell.")
# Bare expression so the notebook echoes the resolved path.
RUN_DIR


In [None]:
# Load the aggregate and per-gene importance tables; both files must exist.
aggregate_path = RUN_DIR / 'feature_importances_mean.csv'
per_gene_path = RUN_DIR / 'feature_importance_per_gene_summary.csv'
if aggregate_path.exists() and per_gene_path.exists():
    aggregate_df = pd.read_csv(aggregate_path)
    per_gene_df = pd.read_csv(per_gene_path)
    print(f'Aggregate features: {len(aggregate_df):,}')
    print(f'Genes summarized: {len(per_gene_df):,}')
    for preview_frame in (aggregate_df, per_gene_df):
        display(preview_frame.head())
else:
    print('Feature importance files missing; skipping feature importance preview.')
    # Empty frames let the following cells detect the missing data and skip.
    aggregate_df, per_gene_df = pd.DataFrame(), pd.DataFrame()


### Dataset-level overlays


In [None]:
# Show the two dataset-level PNGs from RUN_DIR when they exist.
overlay_path = RUN_DIR / "feature_importance_distance_overview.png"
scatter_path = RUN_DIR / "feature_importance_vs_tss_distance.png"
for label, path in [('Cumulative overlay', overlay_path), ('Scatter', scatter_path)]:
    if not path.exists():
        print(f'No file found for {label}: {path}')
        continue
    display(Image(filename=str(path)))

### Per-gene panels


In [None]:
# Preview the first four per-gene panel images, in sorted path order.
panel_dir = RUN_DIR / "per_gene_panels"
panel_paths = sorted(panel_dir.glob('*.png'))
print(f'Found {len(panel_paths)} per-gene panels')
for path in panel_paths[:4]:
    panel_image = Image(filename=str(path))
    display(panel_image)

In [None]:
def render_custom_panel(gene_name: str, top_n: int = 12):
    """Render and display a feature panel for one gene from the aggregate table.

    The rows of the global ``aggregate_df`` whose ``gene_name`` column equals
    ``gene_name`` are passed to ``plot_per_gene_feature_panel``; the image is
    written to ``RUN_DIR / 'per_gene_panels' / 'custom_<gene_name>.png'`` and
    then displayed.

    Args:
        gene_name: Value to match in the ``gene_name`` column.
        top_n: Forwarded to ``plot_per_gene_feature_panel`` as ``top_n``.

    Raises:
        ValueError: If ``aggregate_df`` has no ``gene_name`` column, or no row
            matches ``gene_name``.
    """
    # Check the column explicitly: DataFrame.get() returns None for a missing
    # column, and indexing the frame with the resulting `False` raises an
    # opaque KeyError.
    if 'gene_name' not in aggregate_df.columns:
        raise ValueError(
            f"Aggregate table has no 'gene_name' column; cannot select features for gene {gene_name}"
        )
    subset = aggregate_df[aggregate_df['gene_name'] == gene_name]
    if subset.empty:
        raise ValueError(f'No features found for gene {gene_name}')
    out_path = RUN_DIR / 'per_gene_panels' / f'custom_{gene_name}.png'
    # Make sure the target folder exists; the plotting helper may not create it.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plot_per_gene_feature_panel(subset, gene_name, out_path, top_n=top_n)
    display(Image(filename=str(out_path)))

# Render an example panel for the gene with the largest summed importance.
# Both columns read below must be present; otherwise the example is skipped
# instead of failing with a KeyError on the missing column.
required_per_gene_columns = {'gene', 'importance_mean_sum'}
if per_gene_df.empty or not required_per_gene_columns.issubset(per_gene_df.columns):
    print('No per-gene feature importance data available; skipping example panel.')
else:
    example_gene = per_gene_df.sort_values('importance_mean_sum', ascending=False)['gene'].iloc[0]
    render_custom_panel(example_gene)


In [None]:
def recompute_overlay(max_distance_kb: float | None = None):
    required = {'importance_mean', 'distance_to_tss_kb'}
    if aggregate_df.empty or not required.issubset(aggregate_df.columns):
        print('Feature importance overlay unavailable; missing data.')
        return
    df = aggregate_df.dropna(subset=['importance_mean', 'distance_to_tss_kb'])
    if max_distance_kb is not None:
        df = df[df['distance_to_tss_kb'].abs() <= max_distance_kb]
    out_path = RUN_DIR / 'feature_importance_distance_overlay_custom.png'
    plot_cumulative_importance_overlay(
        df['importance_mean'],
        df['distance_to_tss_kb'],
        out_path,
        'Custom FI cumulative profile'
    )
    display(Image(filename=str(out_path)))

# Example: rebuild the overlay keeping only rows with |distance_to_tss_kb| <= 20.
recompute_overlay(20)
