# Xenium analysis
Related to Figure 2 and Supplementary Figure 2.

This notebook covers two stages:




1. **Major cell-type annotation** (normalization → PCA → optional Harmony → neighbors/UMAP → Leiden → marker scoring → `obs["majortype"]`).
2. **Downstream spatial analysis**:
   - optional merge of compartment-specific subtype annotations into `obs["subtype"]`,
   - cellular neighborhoods (CN) from local composition within a fixed radius,
   - distance-to-target gradients and program scoring for a query population.

## Required input files (relative paths; edit in `CONFIG`)

- Main `AnnData` file (`.h5ad`) with:
  - `obsm["spatial"]`: 2D coordinates (x, y) in **microns**
  - `obs["sample"]`: sample/library identifier
  - `obs["cell_id"]`: cell identifier within sample (if missing, `obs_names` will be used)
  - raw counts in `.layers["counts"]` (if that layer is absent, `.X` is assumed to hold raw counts and is copied into it)

- Optional subset `.h5ad` files providing `obs["subtype"]` for selected compartments (used to populate/overwrite the main object's `obs["subtype"]`).

## Outputs (written to `results/xenium_integrated/`)

- `adata/`:
  - `xenium_integrated.h5ad` (major type + subtype + CN + distance bins + scores)
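- `annotation/`: cluster marker tables, per-cluster marker-score summaries, the cluster → major-type annotation template, and UMAP / dot-plot figures
- `downstream/`: CN tables/plots, distance-bin tables/plots, program scores and ranked genes; a Xenium Explorer cell-group CSV for the example sample is written to `downstream/xenium_explorer/`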


In [None]:
from pathlib import Path

# =========================
# CONFIG (edit paths/keys here)
# =========================
CONFIG = {
    # ---- Input ----
    "INPUT_H5AD": Path("data/QC.no_SOX2OT_EEF1G.h5ad"),

    # Optional: compartment-specific subtype annotation files.
    # These are merged into `obs[SUBTYPE_KEY]` by (sample, cell_id).
    "SUBSET_H5AD": {
        "Hepato": Path("data/Hepato_anno.h5ad"),
        "CAF": Path("data/Fibro_anno.h5ad"),
        "Myeloid": Path("data/Myeloid_anno.h5ad"),
        "T": Path("data/T_anno.h5ad"),
    },
    # Merge order: later entries overwrite earlier ones for the cells they contain.
    "SUBSET_PRECEDENCE": ["Hepato", "CAF", "Myeloid", "T"],

    # ---- Output root ----
    "OUTROOT": Path("results/xenium_integrated"),

    # ---- Keys in AnnData ----
    "KEYS": {
        "sample": "sample",
        "cell_id": "cell_id",
        "spatial": "spatial",
        "majortype": "majortype",
        "subtype": "subtype",
    },
    "GROUP_KEY": "group",     # optional; ignored if missing

    # ---- Reproducibility ----
    "SEED": 0,
    "FIG_DPI": 200,

    # ---- Preprocessing / clustering ----
    "DROP_GENES": [],  # ignored if absent
    "N_HVG": 2500,
    "N_PCS": 30,
    "N_NEIGHBORS": 30,
    "LEIDEN_RES": 1.2,

    # Batch integration (Harmony) across samples if >1 sample present
    "INTEGRATION": "harmony",  # {"harmony", "none"}
    "HARMONY_MAX_ITER": 50,

    # Differential expression per cluster
    "DE_METHOD": "wilcoxon",

    # ---- Example sample (for illustrative spatial plots and Xenium Explorer exports) ----
    # If None, the first sample (sorted) will be used.
    "EXAMPLE_SAMPLE": None,

    # ---- Cellular neighborhoods (CN) ----
    "RUN_CN": True,
    "CN": {
        "radius_um": 100.0,
        "n_clusters": 11,
        "standardize_within_sample": True,
        "key_added": "spatial_cn",
        "cn_obs_key": "CN",
    },

    # ---- Distance-to-target gradients (default: CAF → EMT front) ----
    "DIST": {
        "target_labels": ["EMT_FRONT", "EMT+Invasion-high"],
        "query_regex": [r"^iCAF_", r"^myCAF_"],
        "distance_min_um": 0.0,
        "distance_max_um": 200.0,
        "bin_size_um": 10.0,
        "bayes_alpha": 0.5,
    },

    # ---- Gene/program gradients (query cells only) ----
    # Each signature lists every gene exactly once (duplicates would inflate
    # the "genes present" count used for the minimum-gene check when scoring).
    "GENE_GRADIENT": {
        "genes": ["FAP", "CTHRC1", "SULF1", "PDCD1LG2"],
        "signatures": {
            "ImmuneReg": [
                "IL6","IL11","LIF","CXCL12","CXCL14","CXCL1","CXCL2","CXCL8",
                "CCL2","CCL7","ICAM1","PTGS2","TNFAIP6","SERPINE1","HAS1","HAS2",
                "PDGFRA","IGF1","OSM","PRG4"
            ],
            "ECM": [
                "ACTA2","TAGLN","MYL9","TPM2",
                "COL1A1","COL1A2","COL3A1","COL5A1","COL5A2","COL6A1","COL6A2","COL6A3","COL11A1","COL12A1",
                "FN1","POSTN","SPARC","THBS1","THBS2","LOX","LOXL2","PLOD1","PLOD2","PLOD3","SERPINH1",
                "MMP2","MMP11","MMP14","ITGA11","DDR2","PDGFRB","FBLN1","FBLN2","TNC","ASPN","LUM","DCN","FAP"
            ],
            "Antigen_presentation": [
                "HLA-DRA","HLA-DRB1","HLA-DRB5","HLA-DQA1","HLA-DQB1","HLA-DQA2","HLA-DQB2",
                "HLA-DPA1","HLA-DPB1","HLA-DMA","HLA-DMB","HLA-DOA","HLA-DOB","CD74","CIITA",
                "CTSS","CTSL","CTSB","LGMN","IFI30","LAMP1","LAMP2","RFX5","RFXAP","RFXANK"
            ],
            "Quiescent": [
                "LRAT","RBP1","RELN","LHX2","NGFR","PPARG","GFAP","SYNM","SYP",
                "ALDH1A1","ALDH1A2","RDH10","RARB"
            ],
        },
        "score_ctrl_size": 50,
        "score_n_bins": 25,
        "min_cells_per_sample_bin": 20,
        "smooth_min": 0.6,
        "r2_min": 0.3,
        "top_k": 100,
    },


    # ---- scRNA ↔ Xenium subtype mapping via Jaccard (marker-set overlap) ----
    # This computes subtype-to-subtype similarity between a single-cell reference and Xenium
    # by Jaccard overlap of filtered marker-gene sets (see downstream JACCARD section).
    "JACCARD": {
        "RUN": True,

        # Single-cell reference (AnnData with obs[SC_GROUPBY])
        "SC_H5AD": Path("data/sc_reference.h5ad"),
        "SC_GROUPBY": "subtype",

        # Xenium grouping key (usually same as CONFIG["KEYS"]["subtype"])
        "XE_GROUPBY": "subtype",

        # Optional: restrict Xenium cells to specific major-type(s) before mapping
        # Example: ["CAF"] or ["Myeloid"] depending on your majortype naming.
        "XE_MAJORTYPE_IN": None,

        # Optional: drop specific labels prior to DEG / Jaccard
        "DROP_SC_LABELS": [],
        "DROP_XE_LABELS": [],

        # Normalization used for DE in this section (copies are made; original objects untouched)
        "TARGET_SUM": 1e4,

        # DEG method
        "METHOD": "wilcoxon",

        # Marker-set definition (filters applied on the DE table + in/out expression fraction)
        "N_TOP_SC": 100,
        "N_TOP_XE": 50,
        "MIN_PCT_EXPR": 0.25,
        "MIN_FC": 1.2,          # fold-change cutoff (approx; assumes log2FC in Scanpy output)
        "MAX_FDR": 0.05,
        "MIN_DELTA_PCT": 0.10,  # pct_in - pct_out
        "MAX_PCT_OUT": None,    # optional, e.g. 0.2

        # Remove common "housekeeping" prefixes from marker sets
        "DROP_PREFIX": ("MT-", "RPS", "RPL"),

        # Optional: drop overly shared genes across subtypes to reduce "generic marker" effects
        # None disables; 0.7 means: drop genes appearing in >70% of subtypes
        "DROP_OVER_SHARED_MAX_FRAC": 0.7,

        # Mapping diagnostics
        "JACCARD_THR": 0.05,
        "MIN_CELLS_PER_GROUP": 3,

        # Output name stem (saved under results/xenium_integrated/downstream/jaccard/)
        "OUT_STEM": "jaccard_sc_vs_xenium_subtype",
    },
}

# =========================
# Output directories
# =========================
# Every output lives under OUTROOT; the dict keys are the names later cells use.
OUTROOT = CONFIG["OUTROOT"]
DIRS = dict(
    root=OUTROOT,
    adata=OUTROOT / "adata",
    annotation=OUTROOT / "annotation",
    annotation_fig=OUTROOT / "annotation" / "figures",
    annotation_tbl=OUTROOT / "annotation" / "tables",
    downstream=OUTROOT / "downstream",
    downstream_fig=OUTROOT / "downstream" / "figures",
    downstream_tbl=OUTROOT / "downstream" / "tables",
    xenium_explorer=OUTROOT / "downstream" / "xenium_explorer",
)

# Create the whole tree up front (idempotent).
for d in DIRS.values():
    d.mkdir(parents=True, exist_ok=True)

for label, value in (("INPUT_H5AD", CONFIG["INPUT_H5AD"]), ("OUTROOT", OUTROOT)):
    print(f"[CONFIG] {label}={value}")


In [None]:
# =========================
# Imports, seeds, versions
# =========================
import os
import random
import sys
import platform
import warnings
from importlib import metadata as importlib_metadata

import numpy as np
import pandas as pd

import scanpy as sc
import anndata as ad

from scipy import sparse
from scipy.spatial import cKDTree
from scipy.stats import spearmanr
from sklearn.cluster import MiniBatchKMeans

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

# Optional (used if present)
try:
    import torch
except Exception:
    torch = None

try:
    import scvi
except Exception:
    scvi = None

# Squidpy is only required if CN is enabled
if bool(CONFIG["RUN_CN"]):
    import squidpy as sq

from scipy.ndimage import (
    gaussian_filter,
    label as nd_label,
    binary_fill_holes,
    distance_transform_edt,
)
from skimage.measure import find_contours

# -------------------------
# Reproducibility
# -------------------------
# One integer seed drives every RNG this notebook may touch.
# NOTE: setting PYTHONHASHSEED at runtime does not change hashing in the
# already-running interpreter; it only propagates to child processes.
SEED = int(CONFIG["SEED"])
os.environ["PYTHONHASHSEED"] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)

# Torch (optional import above): seed CPU/GPU generators and request
# deterministic cuDNN kernels; the cuDNN flags are best-effort.
if torch is not None:
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)
    try:
        torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False
    except Exception:
        pass

# scvi-tools (optional import above): best-effort global seed.
if scvi is not None:
    try:
        scvi.settings.seed = SEED
    except Exception:
        pass

# -------------------------
# Plot defaults (vector-friendly)
# -------------------------
# Order matters. `sns.set_style` writes its own axes.spines.* values, and
# scanpy's `set_figure_params` sets figure.dpi (plus, by default, several other
# scanpy rcParams). Both therefore run FIRST, and the project-specific overrides
# are applied LAST so they actually take effect. (Previously the spine settings
# were silently undone by `sns.set_style`, and the explicit "figure.dpi": 120
# was always overridden by FIG_DPI, so it has been dropped.)
sns.set_style("ticks")
sc.settings.verbosity = 2
sc.settings.set_figure_params(dpi=int(CONFIG["FIG_DPI"]), frameon=False)
mpl.rcParams.update(
    {
        # Embed TrueType fonts / keep SVG text as text so figures stay editable.
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "svg.fonttype": "none",
        "axes.spines.top": False,
        "axes.spines.right": False,
        "legend.frameon": False,
    }
)

# -------------------------
# Print key versions
# -------------------------
def _v(pkg: str) -> str:
    try:
        return importlib_metadata.version(pkg)
    except Exception:
        return "NA"

# Interpreter/platform first, then each package's version (or "NA" if absent).
versions = {
    "python": sys.version.split()[0],
    "platform": platform.platform(),
}
versions.update(
    (dist, _v(dist))
    for dist in (
        "scanpy",
        "anndata",
        "numpy",
        "pandas",
        "scipy",
        "scikit-learn",
        "matplotlib",
        "seaborn",
        "squidpy",
        "harmonypy",
        "scikit-image",
        "torch",
        "scvi-tools",
    )
)

print("[VERSIONS]")
for k, v in versions.items():
    print(f"  {k:>12}: {v}")


In [None]:
# =========================
# Helper functions
# =========================

from typing import Dict, List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt

# Canonical marker genes for each *major* cell type (insertion order is kept and
# determines the order of score columns and dot-plot groups downstream).
# A reproducible baseline; adjust to the genes actually present on the panel.
MAJOR_MARKERS: Dict[str, List[str]] = dict(
    T_cell=["CD3D", "CD3E", "TRAC", "IL7R", "LTB"],
    NK_cell=["NKG7", "GNLY", "PRF1", "GZMB"],
    B_cell=["MS4A1", "CD79A", "CD74", "HLA-DRA"],
    Plasma_cell=["MZB1", "JCHAIN", "XBP1", "SDC1"],
    Myeloid=["LST1", "TYROBP", "FCER1G", "LYZ"],
    Dendritic=["FCER1A", "CLEC10A", "ITGAX", "LILRA4"],
    Endothelial=["PECAM1", "VWF", "KDR", "EMCN"],
    Fibroblast=["COL1A1", "COL1A2", "DCN", "LUM", "COL3A1"],
    Pericyte_SMC=["RGS5", "PDGFRB", "CSPG4", "ACTA2", "MCAM"],
    Epithelial=["EPCAM", "KRT8", "KRT18", "KRT19"],
    Hepatocyte=["ALB", "APOA1", "TTR", "FABP1"],
    Mast=["TPSAB1", "TPSB2", "KIT"],
    Erythroid=["HBB", "HBA1", "HBA2"],
)

def _uppercase_lookup(var_names: Sequence[str]) -> Dict[str, str]:
    """Map uppercase gene symbols -> actual var_names entries (first occurrence wins)."""
    out: Dict[str, str] = {}
    for g in var_names:
        gu = str(g).upper()
        if gu not in out:
            out[gu] = str(g)
    return out

def safe_score_genes(
    adata: sc.AnnData,
    gene_list: Sequence[str],
    score_name: str,
    *,
    use_raw: bool = True,
    min_genes: int = 2,
) -> None:
    """Score a marker set with ``sc.tl.score_genes``, tolerating missing genes.

    Gene symbols are matched case-insensitively against the expression source:
    ``adata.raw`` when ``use_raw`` is True and ``.raw`` exists, otherwise
    ``adata`` itself. Each matched gene is counted once.

    Parameters
    ----------
    adata
        AnnData to score; ``adata.obs[score_name]`` is written in place.
    gene_list
        Marker symbols (any case; duplicates allowed).
    score_name
        Name of the obs column to create.
    use_raw
        Prefer ``adata.raw`` if it is present.
    min_genes
        Minimum number of distinct matched genes required to score. Below this
        threshold the column is filled with NaN instead.
    """
    # Only ask scanpy for .raw when it really exists; otherwise the lookup below
    # has already fallen back to adata.var_names and scanpy must do the same.
    raw_available = bool(use_raw) and adata.raw is not None
    src = adata.raw if raw_available else adata
    lookup = _uppercase_lookup(src.var_names)
    # dict.fromkeys removes duplicates while keeping the input order, so a gene
    # listed twice cannot inflate the count checked against min_genes.
    present = list(
        dict.fromkeys(lookup[g.upper()] for g in gene_list if g.upper() in lookup)
    )
    if len(present) < min_genes:
        adata.obs[score_name] = np.nan
        return
    sc.tl.score_genes(adata, gene_list=present, score_name=score_name, use_raw=raw_available)

def save_figure(
    fig: plt.Figure,
    stem: str,
    *,
    figdir: Optional[Path] = None,
) -> Tuple[Path, Path]:
    """Save *fig* as ``<stem>.pdf`` and ``<stem>.png``, then close it.

    Parameters
    ----------
    fig
        Matplotlib figure to write.
    stem
        File name without extension.
    figdir
        Output directory. Defaults to the module-level ``FIGDIR``, looked up at
        call time, so each notebook section can repoint it. Created if missing.

    Returns
    -------
    (pdf_path, png_path)
    """
    outdir = Path(FIGDIR if figdir is None else figdir)
    outdir.mkdir(parents=True, exist_ok=True)
    pdf = outdir / f"{stem}.pdf"
    png = outdir / f"{stem}.png"
    try:
        fig.savefig(pdf, bbox_inches="tight")
        fig.savefig(png, bbox_inches="tight")
    finally:
        # Release the figure even if saving fails, so open figures do not
        # accumulate over a long notebook run.
        plt.close(fig)
    return pdf, png

def _load_subset_labels(path: Path, *, sample_key: str, cell_id_key: str, subtype_key: str) -> pd.DataFrame:
    """Load a subset .h5ad and return a mapping table: (sample, cell_id) -> subtype.

    The returned DataFrame has the three columns ``[sample_key, cell_id_key,
    subtype_key]``, all cast to ``str``.

    - Rows with a missing sample, cell id, or subtype are dropped *before* the
      string cast. Casting first would turn NaN into the literal label "nan",
      which would then overwrite real labels during the merge.
    - If one (sample, cell_id) carries more than one subtype, the first
      occurrence is kept and a warning is emitted. A non-unique key would make
      the downstream ``Series.map`` fail.

    Raises
    ------
    KeyError
        If ``obs[sample_key]`` or ``obs[subtype_key]`` is absent from the file.
    """
    sub = sc.read_h5ad(path)
    obs = sub.obs.copy()

    if sample_key not in obs:
        raise KeyError(f"{path}: obs[{sample_key!r}] not found.")
    if subtype_key not in obs:
        raise KeyError(f"{path}: obs[{subtype_key!r}] not found.")

    if cell_id_key not in obs:
        # Fall back to obs_names if the file does not store `cell_id` explicitly
        obs[cell_id_key] = sub.obs_names.astype(str)

    table = obs[[sample_key, cell_id_key, subtype_key]].dropna().astype(str)
    table = table.drop_duplicates()

    conflicting = table.duplicated(subset=[sample_key, cell_id_key], keep="first")
    if conflicting.any():
        warnings.warn(
            f"[WARN] {path}: {int(conflicting.sum())} cell(s) have conflicting "
            f"{subtype_key!r} labels; keeping the first occurrence."
        )
        table = table.loc[~conflicting]
    return table

def merge_subtype_from_subsets(
    adata: "sc.AnnData",
    subset_paths: dict,
    precedence: list,
    *,
    sample_key: str,
    cell_id_key: str,
    subtype_key: str,
    majortype_key: str = "majortype",
) -> None:
    """
    Overwrite/construct `adata.obs[subtype_key]` by merging subtype annotations from subset files.

    Cells are matched on the composite key ``(sample, cell_id)``. Precedence is
    applied in order: later sources overwrite earlier ones, but only for cells
    present in that subset file. Tags with no path are skipped, and missing files
    are skipped with a warning.

    If ``subtype_key`` does not exist yet, it is seeded from
    ``obs[majortype_key]`` when that column exists; otherwise it starts empty.

    The result is stored as a categorical. Cells without any label stay missing
    (NaN) rather than becoming a literal "<NA>" category.
    """
    if subtype_key not in adata.obs:
        # Fallback: if a coarse label exists, use it; otherwise start empty.
        if majortype_key in adata.obs:
            adata.obs[subtype_key] = adata.obs[majortype_key].astype("string")
        else:
            adata.obs[subtype_key] = pd.NA

    base = adata.obs[subtype_key].astype("string")
    key_series = adata.obs[sample_key].astype(str) + "||" + adata.obs[cell_id_key].astype(str)

    for tag in precedence:
        p = subset_paths.get(tag, None)
        if p is None:
            continue
        if not Path(p).exists():
            warnings.warn(f"[WARN] Subset file not found, skipped: {p}")
            continue

        df = _load_subset_labels(Path(p), sample_key=sample_key, cell_id_key=cell_id_key, subtype_key=subtype_key)
        map_key = df[sample_key].astype(str) + "||" + df[cell_id_key].astype(str)
        mapping = pd.Series(df[subtype_key].values, index=map_key).dropna()
        # Series.map needs a unique index; keep the first label per cell.
        mapping = mapping[~mapping.index.duplicated(keep="first")]
        mapped = key_series.map(mapping)
        mask = mapped.notna()
        if mask.any():
            base.loc[mask] = mapped.loc[mask].astype("string")

        print(f"[MERGE] {tag}: updated {int(mask.sum())} cells from {p}")

    # Convert pd.NA -> NaN before building the categorical; `base.astype(str)`
    # would instead create a spurious "<NA>" category that notna() cannot see.
    adata.obs[subtype_key] = pd.Categorical(base.to_numpy(dtype=object, na_value=np.nan))

from scipy import sparse

# ============================================================
# Core helpers
# ============================================================
def gene_1d(
    adata: sc.AnnData,
    gene: str,
    *,
    layer: Optional[str] = None,
    use_raw: bool = False,
) -> np.ndarray:
    """Extract one gene's values across all cells as a flat float array.

    Source selection, in priority order:

    - ``use_raw=True``: ``adata.raw.X`` (``layer`` is ignored; raw has no layers);
    - ``layer is None``: ``adata.X``;
    - otherwise: ``adata.layers[layer]``.

    Sparse columns are densified.

    Raises
    ------
    ValueError
        ``use_raw=True`` but ``adata.raw`` is None.
    KeyError
        Gene missing from the chosen var_names, or unknown layer.

    Returns
    -------
    np.ndarray
        Shape (n_cells,), dtype float.
    """
    if use_raw:
        if adata.raw is None:
            raise ValueError("use_raw=True but adata.raw is None.")
        if gene not in adata.raw.var_names:
            raise KeyError(f"Gene '{gene}' not found in adata.raw.var_names.")
        column = adata.raw[:, gene].X
    elif gene not in adata.var_names:
        raise KeyError(f"Gene '{gene}' not found in adata.var_names.")
    elif layer is None:
        column = adata[:, gene].X
    elif layer not in adata.layers:
        raise KeyError(f"Layer '{layer}' not found in adata.layers.")
    else:
        column = adata[:, gene].layers[layer]

    dense = column.toarray() if sparse.issparse(column) else np.asarray(column)
    return dense.ravel().astype(float)


In [None]:
# =========================
# Load, preprocess, clustering, major-type annotation
# =========================
# Resolve the obs/obsm column names once from CONFIG["KEYS"].
keys = CONFIG["KEYS"]
sample_key, cell_id_key, spatial_key, majortype_key, subtype_key = (
    keys[name] for name in ("sample", "cell_id", "spatial", "majortype", "subtype")
)

ANN_FIGDIR, ANN_TBLDIR = DIRS["annotation_fig"], DIRS["annotation_tbl"]

# save_figure() reads the global FIGDIR at call time; point it at this section.
FIGDIR = ANN_FIGDIR

# Fail early with an actionable message instead of an h5py error.
if not CONFIG["INPUT_H5AD"].exists():
    raise FileNotFoundError(
        f"Missing input file: {CONFIG['INPUT_H5AD']}. "
        "Update CONFIG['INPUT_H5AD'] or place the file under the expected path."
    )

adata = sc.read_h5ad(CONFIG["INPUT_H5AD"])
print(f"[LOAD] AnnData: {adata.n_obs:,} cells × {adata.n_vars:,} genes")

# ---- Required fields for the full pipeline ----
for k in (sample_key,):
    if k in adata.obs:
        continue
    raise KeyError(f"Required obs column missing: {k!r}")
if spatial_key not in adata.obsm:
    raise KeyError(f"Required obsm key missing: {spatial_key!r} (needed for downstream spatial analysis)")

# Without an explicit cell_id column, obs_names stand in so that
# (sample, cell_id) joins against subset files still work.
if cell_id_key not in adata.obs:
    adata.obs[cell_id_key] = adata.obs_names.astype(str)

# Join keys are compared as plain strings everywhere downstream.
adata.obs[sample_key] = adata.obs[sample_key].astype(str)
adata.obs[cell_id_key] = adata.obs[cell_id_key].astype(str)

# Optional gene removal: only names actually present are dropped.
drop = [g for g in CONFIG["DROP_GENES"] if g in adata.var_names]
if drop:
    keep_genes = ~adata.var_names.isin(drop)
    adata = adata[:, keep_genes].copy()
    print(f"[PREP] Dropped genes: {drop}")

# Preserve the incoming matrix as the counts layer when none was supplied.
if "counts" not in adata.layers:
    adata.layers["counts"] = adata.X.copy()

# QC metrics (kept minimal; upstream QC/filtering is assumed).
# Computed explicitly on the counts layer: when the input already carried a
# separate counts layer, .X may hold normalized values, which would distort
# total_counts (used later as a regression covariate).
adata.var["mt"] = adata.var_names.str.upper().str.startswith("MT-")
# scanpy's default percent_top=(50, 100, 200, 500) rejects cutoffs larger than
# the number of genes, which targeted Xenium panels can fall below; keep only
# valid cutoffs (None disables the metric entirely).
qc_percent_top = [n for n in (50, 100, 200, 500) if n <= adata.n_vars] or None
sc.pp.calculate_qc_metrics(
    adata,
    qc_vars=["mt"],
    percent_top=qc_percent_top,
    layer="counts",
    inplace=True,
)

# Normalize + log1p (working matrix in .X)
adata.X = adata.layers["counts"].copy()
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# Keep a log1p layer and preserve full-gene log1p as .raw
adata.layers["log1p"] = adata.X.copy()
adata.raw = adata

# Batch correction only makes sense with >1 sample; GROUP_KEY is optional.
batch_key = sample_key if adata.obs[sample_key].nunique() > 1 else None
group_key = CONFIG["GROUP_KEY"] if (CONFIG["GROUP_KEY"] in adata.obs.columns) else None

# ------------------- HVGs → PCA → (Harmony) → UMAP → Leiden -------------------
# Targeted panels can contain fewer genes than N_HVG; cap the request at the
# panel size (in that case every gene is kept).
n_top_hvg = min(int(CONFIG["N_HVG"]), adata.n_vars)
hvg_kwargs = dict(n_top_genes=n_top_hvg, flavor="seurat_v3", layer="counts")
if batch_key is not None:
    hvg_kwargs["batch_key"] = batch_key

try:
    sc.pp.highly_variable_genes(adata, **hvg_kwargs)
except ImportError:
    # seurat_v3 requires scikit-misc; fall back to a dependency-free HVG flavor
    # (which works on the log1p values in .X, hence no counts layer).
    hvg_kwargs["flavor"] = "seurat"
    hvg_kwargs.pop("layer", None)
    sc.pp.highly_variable_genes(adata, **hvg_kwargs)

adata = adata[:, adata.var["highly_variable"]].copy()

# Optional regression: only covariates that exist AND vary across cells.
# Panels without MT- genes give pct_counts_mt == 0 everywhere, a degenerate
# regressor that only adds cost.
regress_keys = [
    k for k in ["total_counts", "pct_counts_mt"]
    if k in adata.obs.columns and adata.obs[k].nunique() > 1
]
if regress_keys:
    print(f"[PREP] regress_out: {regress_keys}")
    sc.pp.regress_out(adata, keys=regress_keys)

sc.pp.scale(adata, max_value=10)
sc.tl.pca(adata, n_comps=int(CONFIG["N_PCS"]), svd_solver="arpack", random_state=SEED)

use_rep = "X_pca"
if (
    str(CONFIG["INTEGRATION"]).lower() == "harmony"
    and batch_key is not None
    and adata.obs[batch_key].nunique() > 1
):
    try:
        sc.external.pp.harmony_integrate(
            adata,
            key=batch_key,
            basis="X_pca",
            adjusted_basis="X_pca_harmony",
            max_iter_harmony=int(CONFIG["HARMONY_MAX_ITER"]),
        )
        use_rep = "X_pca_harmony"
    except Exception as e:
        warnings.warn(f"[WARN] Harmony integration failed ({type(e).__name__}: {e}). Falling back to unintegrated PCA.")
        use_rep = "X_pca"

sc.pp.neighbors(
    adata,
    n_neighbors=int(CONFIG["N_NEIGHBORS"]),
    n_pcs=int(CONFIG["N_PCS"]),
    use_rep=use_rep,
    random_state=SEED,
)
sc.tl.umap(adata, random_state=SEED)

cluster_key = f"leiden_r{CONFIG['LEIDEN_RES']}"
sc.tl.leiden(adata, resolution=float(CONFIG["LEIDEN_RES"]), key_added=cluster_key, random_state=SEED)
adata.obs[cluster_key] = adata.obs[cluster_key].astype("category")
print(f"[CLUSTER] {adata.obs[cluster_key].nunique()} clusters in obs['{cluster_key}']")

# ------------------- DE per cluster (markers) -------------------
sc.tl.rank_genes_groups(adata, groupby=cluster_key, method=str(CONFIG["DE_METHOD"]), use_raw=True)
markers_df = sc.get.rank_genes_groups_df(adata, group=None)
markers_csv = ANN_TBLDIR / "cluster_markers_all.csv"
markers_df.to_csv(markers_csv, index=False)
# ------------------- Marker scoring (assist manual annotation) -------------------
score_cols = []
for ct, genes in MAJOR_MARKERS.items():
    col = f"score_{ct}"
    safe_score_genes(adata, genes, score_name=col, use_raw=True)
    score_cols.append(col)

# Aggregate marker scores per cluster (helps manual labeling).
# observed=True groups only by clusters that have cells (every Leiden cluster
# does) and avoids pandas' categorical-groupby FutureWarning.
score_by_cluster_csv = ANN_TBLDIR / "marker_scores_by_cluster.csv"
(
    adata.obs[[cluster_key] + score_cols]
    .groupby(cluster_key, observed=True)
    .mean()
    .to_csv(score_by_cluster_csv)
)

# Export a template for manual cluster -> majortype annotation.
# It is regenerated on every run, so do not edit it in place (see below).
cluster_anno_template_csv = ANN_TBLDIR / "cluster_annotation_template.csv"
pd.DataFrame(
    {
        cluster_key: adata.obs[cluster_key].cat.categories.astype(str),
        majortype_key: "",
    }
).to_csv(cluster_anno_template_csv, index=False)

# ------------------- Manual major-type annotation (NO automatic assignment) -------------------
# Manual step:
# 1) Inspect:
#    - cluster markers table: cluster_markers_all.csv
#    - dotplot of canonical markers by cluster: dotplot_major_markers_by_cluster.(pdf/png)
#    - mean marker-score table by cluster: marker_scores_by_cluster.csv
# 2) Either fill the CLUSTER2MAJORTYPE mapping below (cluster labels are strings), or
#    copy cluster_annotation_template.csv to cluster_annotation.csv (same folder),
#    fill its majortype column, and re-run. That copy is read back here and is
#    never overwritten. Entries in CLUSTER2MAJORTYPE take precedence over the CSV.

CLUSTER2MAJORTYPE = {
    # "0": "T",
    # "1": "Macro",
    # "2": "Hepato",
    # ...
}

cluster_anno_csv = ANN_TBLDIR / "cluster_annotation.csv"
if cluster_anno_csv.exists():
    anno_df = pd.read_csv(cluster_anno_csv, dtype=str, keep_default_na=False)
    if {cluster_key, majortype_key} <= set(anno_df.columns):
        from_csv = {
            str(cl).strip(): str(label).strip()
            for cl, label in zip(anno_df[cluster_key], anno_df[majortype_key])
            if str(label).strip()
        }
        known_clusters = set(adata.obs[cluster_key].cat.categories.astype(str))
        unknown = sorted(set(from_csv) - known_clusters)
        if unknown:
            warnings.warn(
                f"[WARN] {cluster_anno_csv}: clusters not present in obs['{cluster_key}'] "
                f"are ignored: {unknown} (did the clustering change?)"
            )
        CLUSTER2MAJORTYPE = {**from_csv, **CLUSTER2MAJORTYPE}
        print(f"[ANNOT] Loaded {len(from_csv)} cluster labels from {cluster_anno_csv}")
    else:
        warnings.warn(
            f"[WARN] {cluster_anno_csv} lacks column(s) {cluster_key!r}/{majortype_key!r}; ignored "
            "(was LEIDEN_RES changed since it was written?)"
        )

adata.obs[majortype_key] = (
    adata.obs[cluster_key].astype(str).map(CLUSTER2MAJORTYPE).fillna("Unassigned")
).astype("category")

# Summary tables based on current majortype (may contain many 'Unassigned' until mapping is filled)
ct_counts_csv = ANN_TBLDIR / "majortype_counts.csv"
adata.obs[majortype_key].value_counts().rename("n_cells").to_csv(ct_counts_csv)

cluster_by_ct_csv = ANN_TBLDIR / "cluster_by_majortype.csv"
pd.crosstab(adata.obs[cluster_key], adata.obs[majortype_key]).to_csv(cluster_by_ct_csv)

# ------------------- Key figures -------------------
fig = sc.pl.umap(adata, color=[cluster_key], return_fig=True)
save_figure(fig, "umap_clusters")

# Major type, plus batch / group overlays whenever those keys are in use.
umap_colors = [majortype_key]
umap_colors += [batch_key] if batch_key is not None else []
umap_colors += [group_key] if group_key is not None else []
fig = sc.pl.umap(adata, color=umap_colors, return_fig=True)
save_figure(fig, "umap_majortype")

# Dotplot of canonical markers grouped by cluster (for manual annotation).
# Marker symbols are resolved case-insensitively against the full gene set.
raw_var = adata.var_names if adata.raw is None else adata.raw.var_names
lookup = _uppercase_lookup(raw_var)
markers_present = {}
for ct, genes in MAJOR_MARKERS.items():
    present = [lookup[g.upper()] for g in genes if g.upper() in lookup]
    if not present:
        continue
    markers_present[ct] = present

if markers_present:
    dp = sc.pl.dotplot(
        adata,
        markers_present,
        groupby=cluster_key,
        dendrogram=True,
        standard_scale="var",
        return_fig=True,
    )
    dotplot_stem = ANN_FIGDIR / "dotplot_major_markers_by_cluster"
    try:
        for suffix in (".pdf", ".png"):
            dp.savefig(dotplot_stem.with_suffix(suffix))
    except Exception:
        # Fallback: save the underlying matplotlib figure directly.
        ax = dp.get_axes()["mainplot_ax"]
        for suffix in (".pdf", ".png"):
            ax.figure.savefig(dotplot_stem.with_suffix(suffix), bbox_inches="tight")
    plt.close("all")


In [None]:
# =========================
# Subtype merge, cellular neighborhoods, distance gradients
# =========================
FIG_DIR, TAB_DIR, XE_DIR = DIRS["downstream_fig"], DIRS["downstream_tbl"], DIRS["xenium_explorer"]

# ---- Minimal sanity checks ----
for k in (sample_key, cell_id_key, majortype_key):
    if k in adata.obs:
        continue
    raise KeyError(f"Required obs column missing: {k!r}")
if spatial_key not in adata.obsm:
    raise KeyError(f"Required obsm key missing: {spatial_key!r}")

# Defensive re-cast of the join keys so this cell does not depend on
# earlier cells having done it.
adata.obs[sample_key] = adata.obs[sample_key].astype(str)
adata.obs[cell_id_key] = adata.obs[cell_id_key].astype(str)

# =========================
# Construct / merge `subtype`
# =========================
merge_subtype_from_subsets(
    adata,
    subset_paths=CONFIG["SUBSET_H5AD"],
    precedence=CONFIG["SUBSET_PRECEDENCE"],
    sample_key=sample_key,
    cell_id_key=cell_id_key,
    subtype_key=subtype_key,
)

# Example sample for illustrative plots / Xenium Explorer export.
# A configured sample must exist; otherwise a typo would silently produce an
# empty Xenium Explorer CSV.
available_samples = sorted(adata.obs[sample_key].astype(str).unique())
if not available_samples:
    raise ValueError(f"No samples found in obs[{sample_key!r}].")
if CONFIG["EXAMPLE_SAMPLE"] is None:
    example_sample = available_samples[0]
else:
    example_sample = str(CONFIG["EXAMPLE_SAMPLE"])
    if example_sample not in available_samples:
        raise ValueError(
            f"CONFIG['EXAMPLE_SAMPLE']={example_sample!r} not found in obs[{sample_key!r}]; "
            f"available: {available_samples}"
        )
print(f"[DATA] example sample: {example_sample}")

# Xenium Explorer grouping CSV (one example sample): columns cell_id, group.
# NOTE(review): assumes obs[cell_id] holds the original Xenium cell IDs (not
# names prefixed during concatenation) so Explorer can match them — confirm.
xe_csv = XE_DIR / f"{example_sample}_groups_subtype.csv"
(
    adata.obs.loc[adata.obs[sample_key] == example_sample, [cell_id_key, subtype_key]]
    .rename(columns={cell_id_key: "cell_id", subtype_key: "group"})
    .to_csv(xe_csv, index=False)
)
print(f"[OK] Xenium Explorer groups CSV: {xe_csv}")

In [None]:
# =========================
# Jaccard: scRNA subtype ↔ Xenium subtype (marker-set overlap)
# =========================
# Core logic:
# 1) DE per subtype in scRNA and Xenium (wilcoxon)
# 2) Filter markers by (pct_in, FC, FDR, delta_pct, optional pct_out)
# 3) Jaccard overlap of marker sets -> similarity matrix
# 4) Export matrix + best matches (row-wise best + Hungarian 1-1)

from pathlib import Path
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage, leaves_list

# Optional step: compare scRNA-seq subtypes with Xenium subtypes by the overlap of their
# DE marker sets. Runs only when CONFIG["JACCARD"]["RUN"] is truthy and SC_H5AD points to
# an existing file; outputs are written under DIRS["downstream"] / "jaccard".
jcfg = CONFIG.get("JACCARD", {})
if not bool(jcfg.get("RUN", False)):
    print("[JACCARD] Skipped (CONFIG['JACCARD']['RUN']=False)")
else:
    sc_path = Path(jcfg.get("SC_H5AD", ""))
    # is_file() rather than exists(): the "" default becomes Path("."), which exists
    # (it is a directory) and would otherwise pass this guard and crash in read_h5ad.
    if not sc_path.is_file():
        warnings.warn(f"[WARN] SC_H5AD not found: {sc_path}. Skip Jaccard mapping.")
    else:
        # -------------------------
        # 0) Load scRNA reference
        # -------------------------
        ad_sc = sc.read_h5ad(sc_path)
        sc_groupby = str(jcfg.get("SC_GROUPBY", "subtype"))
        xe_groupby = str(jcfg.get("XE_GROUPBY", subtype_key))

        if sc_groupby not in ad_sc.obs:
            raise KeyError(f"[JACCARD] scRNA obs[{sc_groupby!r}] not found in: {sc_path}")
        if xe_groupby not in adata.obs:
            raise KeyError(f"[JACCARD] Xenium obs[{xe_groupby!r}] not found in main adata.")

        # -------------------------
        # 1) Subset Xenium cells (optional majortype filter)
        # -------------------------
        xe_mask = np.ones(adata.n_obs, dtype=bool)

        mt_in = jcfg.get("XE_MAJORTYPE_IN", None)
        if mt_in is not None:
            if majortype_key not in adata.obs:
                raise KeyError(f"[JACCARD] majortype key {majortype_key!r} not in adata.obs; cannot apply XE_MAJORTYPE_IN.")
            mt_in = [str(x) for x in mt_in]
            xe_mask &= adata.obs[majortype_key].astype(str).isin(mt_in).to_numpy()

        # drop NA labels
        xe_mask &= adata.obs[xe_groupby].notna().to_numpy()

        ad_xe = adata[xe_mask].copy()
        print(f"[JACCARD] Xenium subset: {ad_xe.n_obs:,} cells (XE_MAJORTYPE_IN={mt_in})")

        # -------------------------
        # 2) Optional drop labels
        # -------------------------
        drop_sc = set(map(str, jcfg.get("DROP_SC_LABELS", [])))
        drop_xe = set(map(str, jcfg.get("DROP_XE_LABELS", [])))

        if drop_sc:
            ad_sc = ad_sc[~ad_sc.obs[sc_groupby].astype(str).isin(drop_sc)].copy()
        if drop_xe:
            ad_xe = ad_xe[~ad_xe.obs[xe_groupby].astype(str).isin(drop_xe)].copy()

        # remove tiny groups (Scanpy may fail if group size is too small)
        # NOTE: `_inplace_subset_obs` is a private AnnData method; it mutates ad_sc / ad_xe in place.
        min_cells = int(jcfg.get("MIN_CELLS_PER_GROUP", 3))
        for _ad, _k, _tag in [(ad_sc, sc_groupby, "scRNA"), (ad_xe, xe_groupby, "Xenium")]:
            vc = _ad.obs[_k].astype(str).value_counts()
            bad = vc[vc < min_cells].index.tolist()
            if bad:
                warnings.warn(f"[JACCARD] {_tag}: drop {len(bad)} groups with <{min_cells} cells: {bad[:12]}{'...' if len(bad)>12 else ''}")
                _ad._inplace_subset_obs(~_ad.obs[_k].astype(str).isin(bad))
            # clean categories
            _ad.obs[_k] = _ad.obs[_k].astype("category")
            _ad.obs[_k] = _ad.obs[_k].cat.remove_unused_categories()

        print("[JACCARD] scRNA subtype counts:")
        print(ad_sc.obs[sc_groupby].value_counts())
        print("[JACCARD] Xenium subtype counts:")
        print(ad_xe.obs[xe_groupby].value_counts())

        # -------------------------
        # 3) Prepare log-normalized copies (do NOT modify originals)
        # -------------------------
        def _prep_log1p(a: sc.AnnData, target_sum: float = 1e4) -> sc.AnnData:
            """Return a copy of `a` with `.X` = log1p(normalize_total(counts)) and `.raw` set to it.

            Counts are taken from `layers["counts"]`; if that layer is absent, the current `.X`
            is copied into it first (assumes `.X` then holds raw counts — TODO confirm).
            """
            a = a.copy()
            if "counts" not in a.layers:
                a.layers["counts"] = a.X.copy()
            a.X = a.layers["counts"].copy()
            sc.pp.normalize_total(a, target_sum=float(target_sum))
            sc.pp.log1p(a)
            a.raw = a
            return a

        target_sum = float(jcfg.get("TARGET_SUM", 1e4))
        ad_sc2 = _prep_log1p(ad_sc, target_sum=target_sum)
        ad_xe2 = _prep_log1p(ad_xe, target_sum=target_sum)

        # -------------------------
        # 4) Gene intersection (by name)
        # -------------------------
        # NOTE(review): this intersection is case-sensitive, while marker filtering below
        # compares upper-cased names; genes differing only by case are excluded here.
        genes_sc = ad_sc2.raw.var_names if ad_sc2.raw is not None else ad_sc2.var_names
        genes_xe = ad_xe2.raw.var_names if ad_xe2.raw is not None else ad_xe2.var_names
        common_genes = genes_sc.intersection(genes_xe)
        print(f"[JACCARD] common genes: {len(common_genes):,}")
        if len(common_genes) == 0:
            raise ValueError("[JACCARD] No common genes between scRNA and Xenium; check gene naming (symbol vs Ensembl).")

        # -------------------------
        # 5) Rank genes by subtype (store under dedicated keys to avoid overwriting other DE results)
        # -------------------------
        key_sc = "rg_sc_subtype"
        key_xe = "rg_xe_subtype"
        method = str(jcfg.get("METHOD", "wilcoxon"))

        sc.tl.rank_genes_groups(ad_sc2, groupby=sc_groupby, method=method, use_raw=True, key_added=key_sc)
        sc.tl.rank_genes_groups(ad_xe2, groupby=xe_groupby, method=method, use_raw=True, key_added=key_xe)

        # -------------------------
        # 6) Marker sets -> Jaccard matrix
        # -------------------------
        import math
        from collections import Counter

        def _as_upper_set(xs) -> set[str]:
            """Upper-cased string set of `xs` (for case-insensitive gene matching)."""
            return set([str(x).upper() for x in xs])

        def get_marker_sets_from_rank(
            adata: sc.AnnData,
            groupby: str,
            *,
            key: str,
            n_top: int,
            restrict_genes: pd.Index,
            min_pct_expr: float,
            min_fc: float,
            max_fdr: float,
            min_delta_pct: float,
            max_pct_out: float | None,
            drop_prefix: tuple[str, ...] = ("MT-", "RPS", "RPL"),
        ) -> dict[str, set[str]]:
            """Return ``{group: marker gene names}`` from the DE result stored under ``uns[key]``.

            Per group, DE rows are kept when detection fraction in-group >= ``min_pct_expr``,
            ``logfoldchanges`` >= log2(``min_fc``), ``pvals_adj`` <= ``max_fdr``,
            (pct_in - pct_out) >= ``min_delta_pct`` and, if given, pct_out <= ``max_pct_out``.
            Genes starting with a ``drop_prefix`` or not in ``restrict_genes`` (both compared
            upper-cased) are removed; the ``n_top`` highest-``scores`` genes are returned.
            Groups with no DE rows, or where all cells are in (or out of) the group, map to an
            empty set. NOTE(review): gene names that collide after upper-casing would make the
            pct alignment below mismatch in length.
            """
            # Use RAW for pct calculations (raw uses log1p normalized: log1p(0)=0, so >0 is ok)
            if adata.raw is not None:
                X_use = adata.raw.X
                var_names = pd.Index(adata.raw.var_names.astype(str))
            else:
                X_use = adata.X
                var_names = pd.Index(adata.var_names.astype(str))

            # Uppercase mapping
            var_upper = var_names.str.upper()
            upper_to_name = dict(zip(var_upper, var_names))
            rg = _as_upper_set(restrict_genes)

            # group list
            groups = adata.obs[groupby].astype(str).unique().tolist()
            out: dict[str, set[str]] = {}

            # log2FC threshold
            logfc_thr = math.log2(float(min_fc)) if float(min_fc) > 0 else -np.inf

            for g in groups:
                df = sc.get.rank_genes_groups_df(adata, group=g, key=key).copy()
                if df.empty:
                    out[str(g)] = set()
                    continue

                df["NAME_UPPER"] = df["names"].astype(str).str.upper()
                df = df.set_index("NAME_UPPER", drop=False)

                # pct in/out
                mask_in = (adata.obs[groupby].astype(str).values == str(g))
                n_in = int(mask_in.sum())
                n_out = int((~mask_in).sum())
                if n_in == 0 or n_out == 0:
                    out[str(g)] = set()
                    continue

                if sparse.issparse(X_use):
                    pct_in = np.asarray(X_use[mask_in].getnnz(axis=0)).ravel() / n_in
                    pct_out = np.asarray(X_use[~mask_in].getnnz(axis=0)).ravel() / n_out
                else:
                    pct_in = (X_use[mask_in] > 0).mean(axis=0)
                    pct_out = (X_use[~mask_in] > 0).mean(axis=0)

                pct_in = pd.Series(np.asarray(pct_in).ravel(), index=var_upper)
                pct_out = pd.Series(np.asarray(pct_out).ravel(), index=var_upper)

                # align to df
                common = df.index.intersection(pct_in.index)
                df = df.loc[common].copy()
                df["pct_expr"] = pct_in.loc[common].values
                df["pct_out"] = pct_out.loc[common].values
                df["delta_pct"] = df["pct_expr"] - df["pct_out"]

                # columns existence
                if "logfoldchanges" not in df.columns:
                    df["logfoldchanges"] = np.inf
                if "pvals_adj" not in df.columns:
                    df["pvals_adj"] = 0.0

                # filter
                cond = (
                    (df["pct_expr"] >= float(min_pct_expr)) &
                    (df["logfoldchanges"] >= float(logfc_thr)) &
                    (df["pvals_adj"] <= float(max_fdr)) &
                    (df["delta_pct"] >= float(min_delta_pct))
                )
                if max_pct_out is not None:
                    cond &= (df["pct_out"] <= float(max_pct_out))

                df = df.loc[cond].copy()

                # drop prefixes
                if drop_prefix:
                    badmask = df.index.to_series().str.startswith(tuple([p.upper() for p in drop_prefix]))
                    df = df.loc[~badmask]

                # restrict to common genes (UPPER space)
                df = df.loc[df.index.isin(rg)]

                # take top n
                df = df.sort_values("scores", ascending=False)
                top_upper = df.index.tolist()[: int(n_top)]
                top_names = [upper_to_name.get(u, u) for u in top_upper]
                out[str(g)] = set(map(str, top_names))

            return out

        def drop_over_shared_genes(
            marker_sets: dict[str, set[str]], max_frac: float = 0.7
        ) -> tuple[dict[str, set[str]], set[str]]:
            """Remove genes found in more than `max_frac` of the groups' marker sets.

            Returns ``(filtered marker sets, removed genes)``.
            """
            groups = list(marker_sets.keys())
            n = len(groups)
            cnt = Counter()
            for g in groups:
                cnt.update(marker_sets[g])
            bad = {gene for gene, c in cnt.items() if (c / n) > float(max_frac)}
            new = {g: set([x for x in marker_sets[g] if x not in bad]) for g in groups}
            return new, bad

        def jaccard(a: set[str], b: set[str]) -> float:
            """Jaccard index |a & b| / |a | b|; NaN if both sets are empty, 0.0 if only one is."""
            if len(a) == 0 and len(b) == 0:
                return np.nan
            if len(a) == 0 or len(b) == 0:
                return 0.0
            return len(a & b) / len(a | b)

        def jaccard_matrix(markerA: dict[str, set[str]], markerB: dict[str, set[str]]) -> pd.DataFrame:
            """Pairwise Jaccard indices: rows = keys of `markerA`, columns = keys of `markerB`."""
            rows = list(markerA.keys())
            cols = list(markerB.keys())
            mat = np.zeros((len(rows), len(cols)), dtype=float)
            for i, r in enumerate(rows):
                for j, c in enumerate(cols):
                    mat[i, j] = jaccard(markerA[r], markerB[c])
            return pd.DataFrame(mat, index=rows, columns=cols)

        m_sc = get_marker_sets_from_rank(
            ad_sc2, groupby=sc_groupby, key=key_sc,
            n_top=int(jcfg.get("N_TOP_SC", 100)),
            restrict_genes=common_genes,
            min_pct_expr=float(jcfg.get("MIN_PCT_EXPR", 0.25)),
            min_fc=float(jcfg.get("MIN_FC", 1.2)),
            max_fdr=float(jcfg.get("MAX_FDR", 0.05)),
            min_delta_pct=float(jcfg.get("MIN_DELTA_PCT", 0.10)),
            max_pct_out=jcfg.get("MAX_PCT_OUT", None),
            drop_prefix=tuple(jcfg.get("DROP_PREFIX", ("MT-", "RPS", "RPL"))),
        )
        m_xe = get_marker_sets_from_rank(
            ad_xe2, groupby=xe_groupby, key=key_xe,
            n_top=int(jcfg.get("N_TOP_XE", 50)),
            restrict_genes=common_genes,
            min_pct_expr=float(jcfg.get("MIN_PCT_EXPR", 0.25)),
            min_fc=float(jcfg.get("MIN_FC", 1.2)),
            max_fdr=float(jcfg.get("MAX_FDR", 0.05)),
            min_delta_pct=float(jcfg.get("MIN_DELTA_PCT", 0.10)),
            max_pct_out=jcfg.get("MAX_PCT_OUT", None),
            drop_prefix=tuple(jcfg.get("DROP_PREFIX", ("MT-", "RPS", "RPL"))),
        )

        drop_max_frac = jcfg.get("DROP_OVER_SHARED_MAX_FRAC", None)
        if drop_max_frac is not None:
            m_sc, bad_sc = drop_over_shared_genes(m_sc, max_frac=float(drop_max_frac))
            m_xe, bad_xe = drop_over_shared_genes(m_xe, max_frac=float(drop_max_frac))
            print(f"[JACCARD] dropped shared genes: sc={len(bad_sc)}, xenium={len(bad_xe)} (max_frac={drop_max_frac})")

        J = jaccard_matrix(m_sc, m_xe)

        # -------------------------
        # 7) Save outputs
        # -------------------------
        outstem = str(jcfg.get("OUT_STEM", "jaccard_sc_vs_xenium_subtype"))
        JROOT = DIRS["downstream"] / "jaccard"
        J_FIG = JROOT / "figures"
        J_TBL = JROOT / "tables"
        J_FIG.mkdir(parents=True, exist_ok=True)
        J_TBL.mkdir(parents=True, exist_ok=True)

        j_csv = J_TBL / f"{outstem}.matrix.csv"
        J.to_csv(j_csv)
        print(f"[OK] Jaccard matrix: {j_csv}")

        # best match per sc subtype (row-wise)
        # NOTE(review): J can contain NaN (both marker sets empty). `jaccard` keeps NaN here,
        # whereas the margin below uses NaN -> 0; idxmax on an all-NaN row is deprecated in
        # recent pandas.
        best = pd.DataFrame({
            "sc_subtype": J.index.astype(str),
            "best_xenium_subtype": J.idxmax(axis=1).astype(str).values,
            "jaccard": J.max(axis=1).astype(float).values,
        })
        # margin top1-top2
        X = np.nan_to_num(J.values.astype(float), nan=0.0)
        row_sorted = np.sort(X, axis=1)
        top2 = row_sorted[:, -2] if X.shape[1] >= 2 else np.zeros(X.shape[0])
        best["margin_top1_top2"] = best["jaccard"].values - top2
        best = best.sort_values(["jaccard", "margin_top1_top2"], ascending=False).reset_index(drop=True)

        best_csv = J_TBL / f"{outstem}.best_by_sc_subtype.csv"
        best.to_csv(best_csv, index=False)
        print(f"[OK] best-by-row: {best_csv}")

        # Hungarian 1-to-1 matching (maximises total Jaccard via cost = 1 - J; NaN treated as 0)
        cost = 1.0 - X
        row_ind, col_ind = linear_sum_assignment(cost)
        match_1to1 = pd.DataFrame({
            "sc_subtype": J.index[row_ind].astype(str),
            "xenium_subtype": J.columns[col_ind].astype(str),
            "jaccard": X[row_ind, col_ind],
        }).sort_values("jaccard", ascending=False).reset_index(drop=True)

        match_csv = J_TBL / f"{outstem}.match_1to1_hungarian.csv"
        match_1to1.to_csv(match_csv, index=False)
        print(f"[OK] hungarian 1-to-1: {match_csv}")

        # -------------------------
        # 8) Heatmap (clustered order, no dendrogram)
        # -------------------------
        X0 = np.nan_to_num(J.values.astype(float), nan=0.0)

        def _leaf_order(M: np.ndarray) -> np.ndarray:
            """Average-linkage leaf order of the rows of `M` (correlation distance, Euclidean fallback).

            `linkage` needs at least two observations, so with fewer rows the original order is kept.
            """
            if M.shape[0] < 2:
                return np.arange(M.shape[0])
            dist = pdist(M, metric="correlation")
            if np.isnan(dist).any():
                # correlation is undefined for constant rows
                dist = pdist(M, metric="euclidean")
            return leaves_list(linkage(dist, method="average"))

        row_order = _leaf_order(X0)
        col_order = _leaf_order(X0.T)

        Jc = J.iloc[row_order, col_order]

        fig, ax = plt.subplots(figsize=(0.38 * Jc.shape[1] + 4.5, 0.34 * Jc.shape[0] + 4.0))
        sns.heatmap(Jc, cmap="Reds", vmin=0, vmax=np.nanmax(Jc.values), linewidths=0.2, linecolor="white", ax=ax)
        ax.set_title("Subtype mapping: scRNA vs Xenium (Jaccard)")
        ax.set_xlabel("Xenium subtype")
        ax.set_ylabel("scRNA subtype")
        fig.tight_layout()
        heat_pdf = J_FIG / f"{outstem}.heatmap.pdf"
        heat_png = J_FIG / f"{outstem}.heatmap.png"
        fig.savefig(heat_pdf, bbox_inches="tight")
        fig.savefig(heat_png, bbox_inches="tight")
        plt.close(fig)
        print(f"[OK] heatmap: {heat_pdf}")

        # -------------------------
        # 9) Optional: Scanpy matrixplot of z-scored Jaccard (AI-friendly PDF)
        # -------------------------
        # Columns (Xenium subtypes) are z-scored across scRNA subtypes; any failure here only warns.
        try:
            adataJ = sc.AnnData(
                X=Jc.values.astype(float),
                obs=pd.DataFrame(index=Jc.index.astype(str)),
                var=pd.DataFrame(index=Jc.columns.astype(str)),
            )
            adataJ.obs["subtype"] = pd.Categorical(adataJ.obs_names, categories=adataJ.obs_names, ordered=True)

            Xj = adataJ.X
            Xz = (Xj - Xj.mean(axis=0, keepdims=True)) / (Xj.std(axis=0, ddof=0, keepdims=True) + 1e-9)
            adataJ.layers["z"] = Xz
            _bak = adataJ.X.copy()
            adataJ.X = adataJ.layers["z"]

            sc.tl.dendrogram(adataJ, groupby="subtype")
            v = np.nanpercentile(np.abs(adataJ.X), 98)

            # save via object API to avoid scanpy version differences
            mp = sc.pl.matrixplot(
                adataJ,
                var_names=adataJ.var_names,
                groupby="subtype",
                dendrogram=True,
                cmap="RdBu_r",
                vmin=-v, vmax=v,
                colorbar_title="Z-scaled Jaccard",
                show=False,
                return_fig=True,
            )
            mp_path = J_FIG / f"{outstem}.matrixplot_zJaccard.pdf"
            mp.savefig(mp_path, dpi=300, bbox_inches="tight")
            plt.close("all")

            adataJ.X = _bak
            print(f"[OK] matrixplot: {mp_path}")
        except Exception as e:
            warnings.warn(f"[WARN] Scanpy matrixplot failed ({type(e).__name__}: {e}). Heatmap is still available.")

In [None]:
# =========================
# Cellular neighborhoods (CN) from local cell-type composition
# =========================
if bool(CONFIG["RUN_CN"]):
    cn_cfg = CONFIG["CN"]
    cn_key_added = str(cn_cfg["key_added"])
    cn_obs_key = str(cn_cfg["cn_obs_key"])

    # Fixed-radius neighbour graph (radius in the units of obsm[spatial_key]), computed per
    # sample via `library_key`. set_diag=True makes each cell part of its own window.
    sq.gr.spatial_neighbors(
        adata,
        spatial_key=spatial_key,
        library_key=sample_key,
        coord_type="generic",
        delaunay=False,
        radius=float(cn_cfg["radius_um"]),
        set_diag=True,
        key_added=cn_key_added,
    )

    A = adata.obsp[f"{cn_key_added}_connectivities"].tocsr()
    groups = adata.obs[subtype_key].astype("category")
    group_names = list(groups.cat.categories)

    n_cells = adata.n_obs
    n_groups = len(group_names)
    if n_groups == 0:
        raise ValueError(f"[CN] obs[{subtype_key!r}] has no non-missing labels; cannot build neighborhoods.")

    # One-hot subtype matrix (cells x subtypes). Missing subtypes have category code -1, which
    # SciPy rejects as a column index; those cells get an all-zero row instead, so they still
    # receive a CN label but do not count towards any window's composition.
    codes = groups.cat.codes.to_numpy()
    labelled = codes >= 0
    n_unlabelled = int(n_cells - labelled.sum())
    if n_unlabelled > 0:
        warnings.warn(
            f"[CN] {n_unlabelled:,} cells have a missing {subtype_key!r}; "
            "they are excluded from neighborhood composition."
        )
    onehot = sparse.csr_matrix(
        (
            np.ones(int(labelled.sum()), dtype=np.float32),
            (np.flatnonzero(labelled), codes[labelled]),
        ),
        shape=(n_cells, n_groups),
    )

    # Window composition: subtype counts among each cell's neighbours, row-normalised to
    # fractions (windows with no labelled cells stay all-zero).
    window_counts = (A @ onehot).astype(np.float32)
    row_sums = np.asarray(window_counts.sum(axis=1)).ravel()
    row_sums[row_sums == 0] = 1.0
    window_frac = window_counts.multiply(1.0 / row_sums[:, None]).toarray().astype(np.float32)

    X = window_frac
    if bool(cn_cfg["standardize_within_sample"]):
        # z-score each subtype fraction within each sample (samples with < 2 cells stay 0).
        sample_ids = adata.obs[sample_key].astype(str)
        Xz = np.zeros_like(X, dtype=np.float32)
        for s in sorted(sample_ids.unique()):
            idx = (sample_ids == s).to_numpy()
            if idx.sum() < 2:
                continue
            mu = X[idx].mean(axis=0, keepdims=True)
            sd = X[idx].std(axis=0, ddof=0, keepdims=True)
            sd[sd == 0] = 1.0
            Xz[idx] = (X[idx] - mu) / sd
        X = Xz

    km = MiniBatchKMeans(
        n_clusters=int(cn_cfg["n_clusters"]),
        random_state=SEED,
        batch_size=4096,
        n_init=10,
    )
    cn_labels = km.fit_predict(X)
    adata.obs[cn_obs_key] = pd.Categorical([f"CN{c:02d}" for c in cn_labels])

    # Characterise CNs: mean (unstandardised) window composition per CN vs the overall mean.
    comp_df = pd.DataFrame(window_frac, columns=group_names, index=adata.obs_names)
    cn_means = comp_df.groupby(adata.obs[cn_obs_key], observed=True).mean()
    overall = comp_df.mean(axis=0)
    eps = 1e-6
    log2fc = np.log2((cn_means + eps).div(overall + eps, axis=1))

    cn_table_csv = TAB_DIR / "cn_log2fc_vs_overall.csv"
    log2fc.to_csv(cn_table_csv, index=True)

    cn_labels_csv = TAB_DIR / "cn_labels_per_cell.csv"
    (
        adata.obs[[sample_key, cell_id_key, cn_obs_key]]
        .rename(columns={cn_obs_key: "CN"})
        .to_csv(cn_labels_csv, index=False)
    )

    fig = plt.figure(figsize=(0.35 * log2fc.shape[1] + 6, 0.35 * log2fc.shape[0] + 4))
    sns.heatmap(
        log2fc,
        cmap="coolwarm",
        center=0,
        cbar_kws={"label": "log2 fold-change vs overall"},
        linewidths=0.2,
        linecolor="white",
    )
    plt.title(f"Cellular neighborhoods (radius={cn_cfg['radius_um']} µm)")
    plt.xlabel("Subtype")
    plt.ylabel("CN")
    plt.tight_layout()
    cn_fig = FIG_DIR / "cn_log2fc_heatmap.pdf"
    plt.savefig(cn_fig, bbox_inches="tight")
    plt.close()

    print(f"[OK] CN log2FC table: {cn_table_csv}")
    print(f"[OK] CN labels: {cn_labels_csv}")
    print(f"[OK] CN heatmap: {cn_fig}")
else:
    print("[CN] Skipped (CONFIG['RUN_CN']=False)")

# =========================
# Distance-to-target binning (query cells only)
# =========================
# Per-cell annotations shared by the distance-gradient steps below.
dist_cfg = CONFIG["DIST"]
subtype_str = adata.obs[subtype_key].astype(str)
coords = np.asarray(adata.obsm[spatial_key])[:, :2]
samples = adata.obs[sample_key].astype(str)
cell_ids = adata.obs[cell_id_key].astype(str)

# The target is the first configured label that occurs at least once.
target_label = None
for _candidate in dist_cfg["target_labels"]:
    if (subtype_str == _candidate).any():
        target_label = _candidate
        break
if target_label is None:
    raise ValueError(
        "None of CONFIG['DIST']['target_labels'] were found in adata.obs['subtype']. "
        f"Tried: {dist_cfg['target_labels']}"
    )
print(f"[DIST] target_label={target_label!r}")

# Query cells: subtype matches any of the configured regexes.
query_mask = pd.Series(False, index=adata.obs_names)
for _pattern in dist_cfg["query_regex"]:
    query_mask = query_mask | subtype_str.str.contains(_pattern, regex=True, na=False)

target_mask = subtype_str.eq(target_label)
_is_target = target_mask.to_numpy()
_is_query = query_mask.to_numpy()


def _nearest_target_table(sample_id):
    """Distance (same units as coords) from each query cell of `sample_id` to its nearest target cell.

    Returns None when the sample has no target cells or no query cells.
    """
    in_sample = (samples == sample_id).to_numpy()
    tgt = in_sample & _is_target
    qry = in_sample & _is_query
    if not (tgt.any() and qry.any()):
        return None
    tree = cKDTree(coords[tgt])
    try:
        nn_dist, _ = tree.query(coords[qry], k=1, workers=-1)
    except TypeError:  # presumably an older SciPy without the `workers` keyword
        nn_dist, _ = tree.query(coords[qry], k=1)
    return pd.DataFrame(
        {
            sample_key: sample_id,
            cell_id_key: cell_ids[qry].values,
            "query_subtype": subtype_str[qry].values,
            "dist_um": nn_dist,
        }
    )


_per_sample_tables = [_nearest_target_table(sid) for sid in sorted(samples.unique())]
_per_sample_tables = [tbl for tbl in _per_sample_tables if tbl is not None]
if not _per_sample_tables:
    raise RuntimeError("No sample had both target cells and query cells; distance computation aborted.")

# Keep only distances inside [distance_min_um, distance_max_um] (both ends inclusive).
bin_df = pd.concat(_per_sample_tables, ignore_index=True)
_in_range = bin_df["dist_um"].between(float(dist_cfg["distance_min_um"]), float(dist_cfg["distance_max_um"]))
bin_df = bin_df[_in_range].copy()

# Create bins with labels: "0-10", "11-20", ...
# pd.cut bins are right-closed (left, right]; the first bin is also closed on the left.
dmin = float(dist_cfg["distance_min_um"])
dmax = float(dist_cfg["distance_max_um"])
step = float(dist_cfg["bin_size_um"])

# Validate before building edges. A non-positive (or NaN) step, or dmax <= dmin, cannot give two
# edges; previously an empty arange crashed on edges[-1] (IndexError) and step == 0 crashed
# inside np.arange, instead of reaching the ValueError below.
if not (step > 0 and dmax > dmin):
    raise ValueError("Distance binning edges are invalid; check distance_min/max/bin_size.")

# The last arange edge is >= dmax (it overshoots dmax when the range is not a multiple of
# step); the append only guards against floating-point shortfall.
edges = np.arange(dmin, dmax + step, step)
if edges[-1] < dmax:
    edges = np.append(edges, dmax)

if len(edges) < 2:
    raise ValueError("Distance binning edges are invalid; check distance_min/max/bin_size.")

# Integer labels; bins after the first start at edge + 1 so they read as disjoint integer ranges.
# NOTE(review): with a non-integer bin_size_um these truncated labels can be misleading.
bin_labels = [f"{int(edges[0])}-{int(edges[1])}"]
for i in range(1, len(edges) - 1):
    start = int(edges[i] + 1)
    end = int(edges[i + 1])
    bin_labels.append(f"{start}-{end}")

cats = pd.cut(bin_df["dist_um"], bins=edges, right=True, include_lowest=True)
interval_to_label = dict(zip(cats.cat.categories, bin_labels))
bin_df["bin"] = cats.map(interval_to_label)
bin_df = bin_df.dropna(subset=["bin"]).copy()

dist_table_csv = TAB_DIR / "distance_query_to_target_bins.csv"
bin_df.to_csv(dist_table_csv, index=False)
print(f"[OK] distance table: {dist_table_csv}")

# Add distance info back to adata.obs (query cells only; others remain NA)
# Cells are matched by a "sample||cell_id" string key, so this relies on the string-cast
# sample / cell_id columns defined above.
adata.obs["dist_um_to_target"] = np.nan
adata.obs["dist_bin_to_target"] = pd.NA

# NOTE(review): Series.map with a mapper whose index has duplicates likely raises, so
# (sample, cell_id) pairs are assumed to be unique — confirm upstream.
join_key = samples + "||" + cell_ids
bin_key = bin_df[sample_key].astype(str) + "||" + bin_df[cell_id_key].astype(str)
dist_map = pd.Series(bin_df["dist_um"].values, index=bin_key)
bin_map = pd.Series(bin_df["bin"].values, index=bin_key)

mapped_dist = join_key.map(dist_map)
mapped_bin = join_key.map(bin_map)

# Only cells that survived the distance range filter / binning get values.
mask_q = mapped_dist.notna()
adata.obs.loc[mask_q, "dist_um_to_target"] = mapped_dist.loc[mask_q].astype(float).values
adata.obs.loc[mask_q, "dist_bin_to_target"] = pd.Categorical(mapped_bin.loc[mask_q].astype(str))

# Record the distance settings alongside the data.
adata.uns["distance_to_target"] = {
    "target_label": target_label,
    "query_regex": dist_cfg["query_regex"],
    "distance_min_um": dmin,
    "distance_max_um": dmax,
    "bin_size_um": step,
}

# Example spatial visualization (one sample): target cells and regex-matched query cells.
_in_example = (adata.obs[sample_key] == example_sample).to_numpy()
_xy_example = coords[_in_example]
_subtypes_example = subtype_str[_in_example]
_query_pattern = "|".join(dist_cfg["query_regex"])

_example_layers = (
    (_subtypes_example.eq(target_label).to_numpy(), 0.8, f"target: {target_label}"),
    (pd.Series(_subtypes_example).str.contains(_query_pattern, regex=True, na=False).to_numpy(), 0.6, "query (regex)"),
)

fig, ax = plt.subplots(figsize=(5.5, 5.5))
for _sel, _opacity, _legend_label in _example_layers:
    ax.scatter(_xy_example[_sel, 0], _xy_example[_sel, 1], s=2, alpha=_opacity, label=_legend_label)
ax.set_aspect("equal")
ax.set_title(f"{example_sample}: target vs query cells")
ax.set_xlabel("x (µm)")
ax.set_ylabel("y (µm)")
ax.legend(markerscale=4)
fig.tight_layout()

fig_path = FIG_DIR / "example_sample_target_vs_query.pdf"
fig.savefig(fig_path, bbox_inches="tight")
plt.close(fig)
print(f"[OK] example spatial plot: {fig_path}")

# =========================
# Distance-binned composition and enrichment (query cells)
# =========================
# Pseudo-count for the smoothed per-(sample, bin) proportions in the enrichment step below.
alpha = float(dist_cfg["bayes_alpha"])

# Pooled counts over all samples: rows = distance bin, columns = query subtype.
_pooled_counts = bin_df.groupby(["bin", "query_subtype"], observed=True).size()
comp = _pooled_counts.unstack("query_subtype", fill_value=0)

def _bin_center(lbl: str) -> float:
    a, b = lbl.split("-")
    return (float(a) + float(b)) / 2.0

# Order bins by distance (midpoint), not lexicographically.
bin_order = sorted(comp.index.tolist(), key=_bin_center)
comp = comp.reindex(bin_order)
_bin_totals = comp.sum(axis=1)
comp_frac = comp.div(_bin_totals, axis=0).fillna(0)

# Per-sample counts on the full samples x bins grid (missing combinations filled with 0).
_sample_bin_grid = pd.MultiIndex.from_product(
    [sorted(adata.obs[sample_key].unique()), bin_order], names=[sample_key, "bin"]
)
_per_sample_counts = bin_df.groupby([sample_key, "bin", "query_subtype"], observed=True).size()
comp_by_sample = _per_sample_counts.unstack("query_subtype", fill_value=0).reindex(_sample_bin_grid, fill_value=0)

out_prefix = "query_to_target"
comp_csv = TAB_DIR / f"{out_prefix}_composition_counts_over_bins.csv"
frac_csv = TAB_DIR / f"{out_prefix}_composition_fraction_over_bins.csv"
comp_by_sample_csv = TAB_DIR / f"{out_prefix}_composition_counts_by_sample_over_bins.csv"

_tables_to_write = ((comp, comp_csv), (comp_frac, frac_csv), (comp_by_sample, comp_by_sample_csv))
for _table, _csv_path in _tables_to_write:
    _table.to_csv(_csv_path)
for _, _csv_path in _tables_to_write:
    print(f"[OK] {_csv_path}")

# Figure: pooled stacked bar (fraction within each bin)
# Subtypes are stacked from most to least abundant overall; each bar sums to 1 (or 0 if empty).
subtype_order = comp.sum(axis=0).sort_values(ascending=False).index.tolist()
frac_use = comp_frac[subtype_order]

_n_sub = len(subtype_order)
_palette_name = "tab20" if _n_sub <= 20 else "husl"
palette = dict(zip(subtype_order, sns.color_palette(_palette_name, n_colors=_n_sub)))

fig_w = float(np.clip(0.8 * len(bin_order) + 4.0, 8.0, 22.0))
fig, ax = plt.subplots(figsize=(fig_w, 4.8))
x = np.arange(len(bin_order))

_baseline = np.zeros(len(bin_order), dtype=float)
for _subtype_name in subtype_order:
    _heights = frac_use[_subtype_name].to_numpy()
    ax.bar(
        x,
        _heights,
        0.82,
        bottom=_baseline,
        label=_subtype_name,
        color=palette[_subtype_name],
        edgecolor="white",
        linewidth=0.3,
    )
    _baseline = _baseline + _heights

# Annotate each bar with the pooled number of query cells in that bin.
for _xpos, _bin_total in zip(x, comp.sum(axis=1).to_numpy()):
    ax.text(_xpos, 1.01, f"n={int(_bin_total)}", ha="center", va="bottom", fontsize=8)

ax.set_xticks(x)
ax.set_xticklabels(bin_order, rotation=0)
ax.set_ylabel("Fraction within bin")
ax.set_xlabel(f"Distance to nearest {target_label} (µm)")
ax.set_ylim(0, 1.08)

legend_cols = min(6, max(3, int(np.ceil(_n_sub / 2))))
ax.legend(ncol=legend_cols, bbox_to_anchor=(0.5, -0.18), loc="upper center", title="Query subtype")

fig.tight_layout()
fig_a = FIG_DIR / f"A_{out_prefix}_stacked_fraction.pdf"
fig.savefig(fig_a, bbox_inches="tight")
plt.close(fig)
print(f"[OK] {fig_a}")

# Figure: log2 enrichment vs sample baseline (bubble heatmap)
# Baseline per sample: fraction of each query subtype among all of that sample's binned query
# cells (zero fractions -> NaN so they drop out of the log ratio below).
total_by_sample_subtype = comp_by_sample.groupby(level=sample_key).sum()
p_base = total_by_sample_subtype.div(total_by_sample_subtype.sum(axis=1), axis=0).replace(0, np.nan)

# Long table: one row per (sample, bin, query_subtype) with count y and (sample, bin) total n.
n_sb = comp_by_sample.sum(axis=1)  # total query cells per (sample, bin)
long_counts = comp_by_sample.stack().rename("y").reset_index()
long_totals = n_sb.rename("n").reset_index()
long = long_counts.merge(long_totals, on=[sample_key, "bin"], how="left")

p_base_long = p_base.stack().rename("p_base").reset_index()
p_base_long = p_base_long.rename(columns={0: "p_base", "level_1": "query_subtype"})
long = long.merge(p_base_long, on=[sample_key, "query_subtype"], how="left")

# comp_by_sample is zero-filled over the full samples x bins grid. For a (sample, bin) with no
# query cells (n == 0) the smoothed estimate below equals alpha / (2 * alpha) = 0.5 for every
# subtype, a pure pseudo-count artefact that would inflate the per-bin mean enrichment.
# Keep only (sample, bin) pairs that actually contain query cells.
long = long.loc[long["n"] > 0].copy()

# Pseudo-count-smoothed proportion, then log2 ratio to the sample baseline.
long["p_hat"] = (long["y"] + alpha) / (long["n"] + 2 * alpha)
long["log2_enrich_vs_sample"] = np.log2((long["p_hat"] / long["p_base"]).replace(0, np.nan))

agg = long.replace([np.inf, -np.inf], np.nan).dropna(subset=["log2_enrich_vs_sample"])
if agg.empty:
    warnings.warn("[WARN] no finite enrichment values; skipping enrichment tables and bubble plot.")
else:
    enrich_mean = agg.groupby(["bin", "query_subtype"])["log2_enrich_vs_sample"].mean().unstack(fill_value=0)
    count_sum = agg.groupby(["bin", "query_subtype"])["y"].sum().unstack(fill_value=0)

    # groupby sorts bin labels as strings (e.g. "11-15" before "6-10"); restore distance order
    # so the x-axis matches the other distance plots.
    bins_present = [b for b in bin_order if b in enrich_mean.index]
    enrich_mean = enrich_mean.reindex(bins_present)
    count_sum = count_sum.reindex(index=bins_present, columns=enrich_mean.columns, fill_value=0)

    enrich_mean.to_csv(TAB_DIR / f"{out_prefix}_log2_enrichment_vs_sample_baseline_mean.csv")
    count_sum.to_csv(TAB_DIR / f"{out_prefix}_counts_for_enrichment_heatmap.csv")

    bins = enrich_mean.index.tolist()
    subtypes = enrich_mean.columns.tolist()

    # One marker per (bin, subtype): colour = mean log2 enrichment, size ~ log10(total count + 1).
    Xp, Yp, Cc, Ss = [], [], [], []
    for i, b in enumerate(bins):
        for j, st in enumerate(subtypes):
            Xp.append(i)
            Yp.append(j)
            Cc.append(enrich_mean.loc[b, st])
            Ss.append(np.log10(count_sum.loc[b, st] + 1.0) * 60.0)

    fig, ax = plt.subplots(figsize=(0.55 * len(bins) + 4, 0.35 * len(subtypes) + 3))
    sca = ax.scatter(Xp, Yp, c=Cc, s=Ss, cmap="coolwarm", vmin=-2, vmax=2, edgecolors="k", linewidths=0.2)
    ax.set_xticks(range(len(bins)))
    ax.set_xticklabels(bins, rotation=0)
    ax.set_yticks(range(len(subtypes)))
    ax.set_yticklabels(subtypes)
    ax.set_xlabel(f"Distance to nearest {target_label} (µm)")
    ax.set_ylabel("Query subtype")
    cb = fig.colorbar(sca, ax=ax, shrink=0.8)
    cb.set_label("Mean log2 enrichment vs sample baseline")
    fig.tight_layout()
    fig_c = FIG_DIR / f"C_{out_prefix}_bubble_enrichment_vs_sample.pdf"
    fig.savefig(fig_c, bbox_inches="tight")
    plt.close(fig)
    print(f"[OK] {fig_c}")

# =========================
# Gene/program gradients along distance bins (query cells)
# =========================
gg_cfg = CONFIG["GENE_GRADIENT"]

# Gene-level summaries read the full-gene matrix in adata.raw.
_raw_layer = adata.raw
if _raw_layer is None:
    raise RuntimeError("adata.raw is required for gene/program gradients (full-gene log1p matrix).")
raw_var = _raw_layer.var_names
raw_lookup = _uppercase_lookup(raw_var)
X_expr = _raw_layer.X

# Query cells = cells that received a distance bin in the step above.
q_mask = adata.obs["dist_bin_to_target"].notna().to_numpy()
if not q_mask.any():
    raise RuntimeError("No query cells have distance bins assigned; cannot run gene/program gradients.")

# Per-query-cell metadata; `_row_ix` is the row position of the cell in adata / X_expr.
q_idx = np.flatnonzero(q_mask)
q_meta = (
    adata.obs.loc[q_mask, [sample_key, cell_id_key, subtype_key, "dist_bin_to_target"]]
    .copy()
    .rename(columns={"dist_bin_to_target": "bin"})
    .reset_index(drop=True)
)
q_meta["bin"] = q_meta["bin"].astype(str)
q_meta["_row_ix"] = q_idx

# ---- 1) Selected genes: mean expression vs distance bin ----
# Configured genes are de-duplicated case-insensitively (order preserved). A gene listed twice
# would otherwise produce duplicate columns, making `mean_by_bin[g]` two-dimensional and
# breaking `fill_between` below.
genes_u = list(dict.fromkeys(str(g).upper() for g in gg_cfg["genes"]))
genes_present = list(dict.fromkeys(raw_lookup[g] for g in genes_u if g in raw_lookup))
missing_genes = sorted(set(genes_u) - set(raw_lookup.keys()))
if missing_genes:
    warnings.warn(f"[WARN] missing genes skipped: {missing_genes}")

if len(genes_present) > 0:
    gene_ix = [raw_var.get_loc(g) for g in genes_present]
    Xg = X_expr[:, gene_ix]
    min_cells_sb = int(gg_cfg["min_cells_per_sample_bin"])

    # Mean expression per (sample, bin); groups below min_cells_per_sample_bin are skipped.
    rows = []
    for (s, b), df_g in q_meta.groupby([sample_key, "bin"], observed=True):
        if len(df_g) < min_cells_sb:
            continue
        ix = df_g["_row_ix"].to_numpy()
        mu = np.asarray(Xg[ix].mean(axis=0)).ravel()
        rows.append(pd.Series(mu, index=genes_present, name=(str(s), str(b))))

    if len(rows) > 0:
        gene_means = pd.DataFrame(rows)
        gene_means.index = pd.MultiIndex.from_tuples(gene_means.index, names=[sample_key, "bin"])

        gene_means_csv = TAB_DIR / "selected_genes_mean_expression_by_sample_bin.csv"
        gene_means.to_csv(gene_means_csv)

        # Across-sample mean and SEM per bin (samples are the replicates), in distance order.
        mean_by_bin = gene_means.groupby(level="bin").mean().reindex(bin_order)
        sem_by_bin = gene_means.groupby(level="bin").sem(ddof=1).reindex(bin_order)

        x = np.arange(len(mean_by_bin.index))
        fig, ax = plt.subplots(figsize=(8.2, 4.6))
        for g in genes_present:
            y = mean_by_bin[g].to_numpy()
            ax.plot(x, y, label=g)
            lo = y - 2 * sem_by_bin[g].to_numpy()
            hi = y + 2 * sem_by_bin[g].to_numpy()
            ax.fill_between(x, lo, hi, alpha=0.15)

        ax.set_xticks(x)
        ax.set_xticklabels(mean_by_bin.index.tolist(), rotation=0)
        ax.set_xlabel(f"Distance to nearest {target_label} (µm)")
        ax.set_ylabel("Mean log1p-normalized expression (±2 SEM)")
        ax.legend(ncol=2)

        fig.tight_layout()
        fig_path = FIG_DIR / "G_selected_genes_mean_vs_distance.pdf"
        fig.savefig(fig_path, bbox_inches="tight")
        plt.close(fig)

        print(f"[OK] {gene_means_csv}")
        print(f"[OK] {fig_path}")
    else:
        warnings.warn("[WARN] no (sample, bin) groups passed min_cells_per_sample_bin; skipping selected-gene summaries.")
else:
    warnings.warn("[WARN] none of the configured genes are present; skipping selected-gene plots.")

# ---- 2) Signature scores (Scanpy score_genes) ----
# Each signature's genes are resolved case-insensitively against adata.raw (duplicates dropped);
# signatures with fewer than 5 resolvable genes are skipped. Scores go to adata.obs["sig_<name>"].
sig_scores = []
for _sig_name, _sig_genes in gg_cfg["signatures"].items():
    _upper_keys = (str(g).upper() for g in _sig_genes)
    _resolved = list(dict.fromkeys(raw_lookup[k] for k in _upper_keys if k in raw_lookup))

    if len(_resolved) < 5:
        warnings.warn(f"[WARN] signature {_sig_name!r}: too few genes present ({len(_resolved)}); skipped.")
        continue

    _obs_col = f"sig_{_sig_name}"
    sc.tl.score_genes(
        adata,
        gene_list=_resolved,
        score_name=_obs_col,
        ctrl_size=int(gg_cfg["score_ctrl_size"]),
        n_bins=int(gg_cfg["score_n_bins"]),
        random_state=SEED,
        use_raw=True,
    )
    sig_scores.append(_obs_col)

if not sig_scores:
    print("[SIG] No signature scores were computed (no gene sets passed the presence threshold).")
else:
    # Mean score per (sample, bin) over query cells, on the full samples x bins grid (NaN if absent).
    _sig_grid = pd.MultiIndex.from_product(
        [sorted(adata.obs[sample_key].unique()), bin_order], names=[sample_key, "bin"]
    )
    score_meta = (
        adata.obs.loc[q_mask, [sample_key, "dist_bin_to_target"] + sig_scores]
        .copy()
        .rename(columns={"dist_bin_to_target": "bin"})
    )
    score_meta["bin"] = score_meta["bin"].astype(str)
    score_means = score_meta.groupby([sample_key, "bin"], observed=True)[sig_scores].mean().reindex(_sig_grid)

    score_means_csv = TAB_DIR / "signature_scores_mean_by_sample_bin.csv"
    score_means.to_csv(score_means_csv)

    mean_by_bin = score_means.groupby(level="bin").mean().reindex(bin_order)
    sem_by_bin = score_means.groupby(level="bin").sem(ddof=1).reindex(bin_order)

    x = np.arange(len(mean_by_bin.index))
    fig, ax = plt.subplots(figsize=(8.2, 4.6))
    for _col in sig_scores:
        _centre = mean_by_bin[_col].to_numpy()
        _band = 2 * sem_by_bin[_col].to_numpy()
        ax.plot(x, _centre, label=_col.replace("sig_", ""))
        ax.fill_between(x, _centre - _band, _centre + _band, alpha=0.15)

    ax.set_xticks(x)
    ax.set_xticklabels(mean_by_bin.index.tolist(), rotation=0)
    ax.set_xlabel(f"Distance to nearest {target_label} (µm)")
    ax.set_ylabel("Mean signature score (±2 SEM)")
    ax.legend(ncol=2)

    fig.tight_layout()
    fig_path = FIG_DIR / "G_signature_scores_mean_vs_distance.pdf"
    fig.savefig(fig_path, bbox_inches="tight")
    plt.close(fig)

    print(f"[OK] {score_means_csv}")
    print(f"[OK] {fig_path}")

# ---- 3) Rank genes by monotonic gradient across bins ----
def rank_gradient_genes(
    adata: "sc.AnnData",
    *,
    q_mask: np.ndarray,
    bin_labels: list,
    sample_key: str,
    min_cells_per_group: int,
    smooth_min: float,
    r2_min: float,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Return per-gene gradient metrics across ordered distance bins, plus a filtered table.

    Expression of the query cells is averaged per (sample, bin) group. Groups
    with fewer than ``min_cells_per_group`` cells are discarded, and the
    surviving group means are averaged within each bin. Each gene's bin-mean
    profile is then summarised against the numeric bin centres by:

    - ``smooth``: ``1 - mean(|step between adjacent bins|) / (max - min)``.
      Jagged profiles score lower. A strictly monotone profile over ``n`` bins
      scores ``1 - 1/(n-1)``. Flat genes score ~1 because of the epsilon in the
      denominator, but they get ``spearman_r = 0``.
    - ``spearman_r`` / ``spearman_r2``: Spearman rank correlation with the bin
      centres (NaN, e.g. for constant genes, is replaced by 0) and its square.
    - ``slope``: least-squares slope of bin mean vs. bin centre.

    Parameters
    ----------
    adata
        AnnData with ``obs[sample_key]`` and ``obs["dist_bin_to_target"]``.
        ``adata.raw`` supplies expression when present, otherwise ``adata.X``.
    q_mask
        Boolean mask over ``adata.obs`` rows selecting the query cells.
    bin_labels
        Ordered bin labels. They are compared against ``obs["dist_bin_to_target"]``
        after ``astype(str)`` and converted to numeric centres with ``_bin_center``.
    sample_key
        ``obs`` column identifying the sample (replicate unit).
    min_cells_per_group
        Minimum number of query cells for a (sample, bin) group to be used.
    smooth_min, r2_min
        Thresholds on ``smooth`` and ``spearman_r2`` for the filtered table.

    Returns
    -------
    (metrics, filtered)
        ``metrics`` has one row per gene (columns ``gene``, ``smooth``,
        ``spearman_r``, ``spearman_r2``, ``slope``), sorted by descending slope.
        ``filtered`` is the subset passing both thresholds.

    Raises
    ------
    RuntimeError
        If no (sample, bin) group passes ``min_cells_per_group``, or if fewer
        than two bins retain any data.
    """
    X = adata.raw.X if adata.raw is not None else adata.X
    var_names = adata.raw.var_names if adata.raw is not None else adata.var_names

    q_idx = np.where(q_mask)[0]
    meta = adata.obs.loc[q_mask, [sample_key, "dist_bin_to_target"]].copy()
    meta = meta.rename(columns={"dist_bin_to_target": "bin"})
    meta["bin"] = meta["bin"].astype(str)
    # Positional row of each query cell in X (obs.loc with a mask keeps obs order).
    meta["_row_ix"] = q_idx

    group_rows = []
    group_index = []
    for (s, b), df_g in meta.groupby([sample_key, "bin"], observed=True):
        if len(df_g) < min_cells_per_group:
            continue
        ix = df_g["_row_ix"].to_numpy()
        # asarray(...).ravel() flattens the 1 x n_genes matrix a sparse X returns.
        group_rows.append(np.asarray(X[ix].mean(axis=0)).ravel())
        group_index.append((str(s), str(b)))

    if not group_rows:
        raise RuntimeError("No (sample, bin) groups passed min_cells_per_group; cannot rank genes.")

    mean_df = pd.DataFrame(
        np.vstack(group_rows),
        index=pd.MultiIndex.from_tuples(group_index, names=[sample_key, "bin"]),
        columns=var_names,
    )

    # Average samples within each bin, in the requested bin order. A bin where no
    # group survived comes back as an all-NaN row. Left in place, that row would
    # turn every gene's smooth/slope into NaN and silently empty the filtered
    # table, so such bins are dropped and centres are taken for the kept bins only.
    bin_means = mean_df.groupby(level="bin").mean().reindex(bin_labels).dropna(how="all")
    if bin_means.shape[0] < 2:
        raise RuntimeError(
            "Fewer than two distance bins have data after min_cells_per_group filtering; "
            "cannot rank genes."
        )
    centers = np.array([_bin_center(b) for b in bin_means.index], dtype=float)
    Y = bin_means.to_numpy()

    # Smoothness: mean absolute step between adjacent bins relative to the total range.
    dy = np.abs(np.diff(Y, axis=0))
    y_range = (Y.max(axis=0) - Y.min(axis=0)) + 1e-9  # epsilon avoids 0/0 for flat genes
    smooth = 1.0 - (dy.mean(axis=0) / y_range)

    # Rank correlation with bin centres; constant genes give NaN, mapped to 0.
    r = np.array([spearmanr(centers, Y[:, j]).correlation for j in range(Y.shape[1])])
    r = np.nan_to_num(r, nan=0.0)
    r2 = r ** 2

    # Ordinary least-squares slope of bin mean on bin centre.
    x0 = centers - centers.mean()
    var_x = (x0 ** 2).sum()
    slope = (x0[:, None] * (Y - Y.mean(axis=0)[None, :])).sum(axis=0) / var_x

    metrics = pd.DataFrame(
        {
            "gene": var_names,
            "smooth": smooth,
            "spearman_r": r,
            "spearman_r2": r2,
            "slope": slope,
        }
    ).sort_values("slope", ascending=False)

    keep = (metrics["smooth"] >= smooth_min) & (metrics["spearman_r2"] >= r2_min)
    return metrics, metrics.loc[keep].copy()

metrics_all, metrics_filt = rank_gradient_genes(
    adata,
    q_mask=q_mask,
    bin_labels=bin_order,
    sample_key=sample_key,
    min_cells_per_group=int(gg_cfg["min_cells_per_sample_bin"]),
    smooth_min=float(gg_cfg["smooth_min"]),
    r2_min=float(gg_cfg["r2_min"]),
)

# Full per-gene metrics table (unfiltered).
metrics_csv = TAB_DIR / "gene_gradient_metrics.csv"
metrics_all.to_csv(metrics_csv, index=False)

# Top-k genes with positive and with negative slope, taken from the filtered table.
# Each list is restricted to its slope sign, so that when fewer than 2*top_k genes
# pass the filters a gene cannot appear in both lists and a negative-slope gene
# is never reported among the "pos" genes.
top_k = int(gg_cfg["top_k"])
top_pos = (
    metrics_filt.loc[metrics_filt["slope"] > 0]
    .sort_values("slope", ascending=False)
    .head(top_k)
    .assign(direction="pos")
)
top_neg = (
    metrics_filt.loc[metrics_filt["slope"] < 0]
    .sort_values("slope", ascending=True)
    .head(top_k)
    .assign(direction="neg")
)
top_tbl = pd.concat([top_pos, top_neg], axis=0, ignore_index=True)

top_csv = TAB_DIR / f"ranked_genes_top{top_k}pos_top{top_k}neg.csv"
top_tbl.to_csv(top_csv, index=False)

print(f"[OK] {metrics_csv}")
print(f"[OK] {top_csv}")


# ---- Save final object ----
# Persist the integrated AnnData (major type, subtype, CN, distance bins, scores)
# into the adata output directory.
final_h5ad = DIRS["adata"].joinpath("xenium_integrated.h5ad")
adata.write_h5ad(final_h5ad)
print("[OK] wrote: " + str(final_h5ad))
