# Xenium PDAC external validation (M9)


**Purpose.** Reproduce / validate HCC-derived *spatial structure* signals (tumor–microenvironment interface) in an **independent Xenium PDAC** dataset by computing:

- **Tumor boundary cells** (KNN-based)
- **Distance-to-boundary** metrics (optionally via a smoothed boundary contour)
- **Tumor EMT balance score** (tumor-only; MES vs EPI programs)
- A few **summary plots/tables** for cross-cohort comparison

**Input.**
- `./data/PDAC_Xenium.h5ad` (AnnData; referenced as `PDAC_Xenium.h5ad`)
- *(Optional)* Xenium `cell_boundaries.parquet` or `cell_boundaries.csv.gz` for polygon-level ROI plots

**Outputs.** All figures/tables/updated objects are saved under `./outputs/xenium_pdac/`.

**Notes.**
- Samples are **renamed to `PDAC_P1 ... PDAC_Pn`** for reporting (the raw IDs are preserved in `adata.obs['sample_raw']`).

In [None]:
## =========================
## CONFIG (paths + key params)
## =========================
# Central place for paths, AnnData keys and tuning parameters read by every
# later cell; edit values here rather than inside the helper calls.
from pathlib import Path

# Reproducibility
# (used to seed Python's `random` and NumPy's global RNG in the next cell)
SEED = 550

# I/O
DATA_DIR   = Path("./data")
OUT_DIR    = Path("./outputs") / "xenium_pdac"
FIG_DIR    = OUT_DIR / "figures"
TABLE_DIR  = OUT_DIR / "tables"
# Output folders are created eagerly; exist_ok=True makes re-running the cell safe.
OUT_DIR.mkdir(parents=True, exist_ok=True)
FIG_DIR.mkdir(parents=True, exist_ok=True)
TABLE_DIR.mkdir(parents=True, exist_ok=True)

# Input AnnData
H5AD_PATH = DATA_DIR / "PDAC_Xenium.h5ad"

# Column / key names (adjust if your AnnData uses different keys)
SAMPLE_KEY   = "sample"
CELLTYPE_KEY = "majortype"     # broad cell type annotation
TUMOR_LABEL  = "Ductal"        # PDAC tumor epithelial label (modify if needed)
SPATIAL_KEY  = "spatial"       # adata.obsm[SPATIAL_KEY] must exist
COUNTS_LAYER = "counts"        # raw counts layer used for gene signature scoring

# -------------------------
# Boundary detection params
# -------------------------
BOUNDARY_KNN_K = 15            # neighbors for boundary-cell call (per sample)
RUN_SMOOTH_BOUNDARY_FIT = True # if True: fit smoothed boundary curve + compute dist_to_boundary

# Smoothed boundary fitting (KNN classifier on a grid + contour at p=0.5)
CLF_K             = 25         # K for classifier (often larger than BOUNDARY_KNN_K)
GRID_RESOLUTION   = 400        # 300~600 grid points per axis; larger = finer contour but slower (SMOOTH_SIGMA is in grid cells, so it smooths less in µm)
PADDING_PERCENT   = 0.05       # fraction (not percent) of the x/y extent added on each side of the grid
SMOOTH_SIGMA      = 1.6        # gaussian smoothing on probability grid (sigma in grid cells)
MAX_DIST_THRESHOLD= 900        # mask grid points far from any cell (same unit as spatial coords; Xenium usually µm)
MIN_VERTEX_COUNT  = 120        # filter tiny contour segments (segments with fewer vertices are dropped)

# Margin bands around the boundary (µm)
# Bands are MARGIN_STEP_UM wide; signed distances outside the outermost bins get no band (NaN).
MARGIN_STEP_UM  = 50
MARGIN_RANGE_UM = 100          # symmetric range [-range, +range]

# -------------------------
# EMT balance score params
# -------------------------
TARGET_SUM = 1e4               # normalize_total target
# MIN_GENES_EACH: minimum genes that must survive filtering for each of MES / EPI (else ValueError).
# MIN_FRAC_NONZERO / MIN_MEAN_COUNTS: per-gene thresholds on raw counts, evaluated among tumor cells only.
# ZSCORE_WITHIN_TUMOR: z-score MES and EPI across tumor cells before taking MES - EPI.
MIN_GENES_EACH      = 3
MIN_FRAC_NONZERO    = 0.01
MIN_MEAN_COUNTS     = 0.005
ZSCORE_WITHIN_TUMOR = True

# Candidate gene sets (panel-dependent; will be auto-filtered by coverage + signal)
EPI_PDACHQ = [
    "EPCAM","TACSTD2",
    "KRT8","KRT18","KRT19","KRT7",
    "CLDN4","CLDN3","OCLN","TJP1",
    "MUC1","MSLN",
    "SOX9","KLF5","S100P",
    "EHF","GRHL2","DSP","DSG2","DSC2",
]
MES_PDACHQ = [
    "VIM","ZEB1","ZEB2","SNAI1","SNAI2","TWIST1","TWIST2",
    "FN1","SPARC","TAGLN",
    "ITGA5","ITGAV","ITGB1","ITGB4",
    "CDH2","LGALS3",
    "ANXA1","ANXA2",
    "SERPINE1","TGFBI",
    "MMP14","MMP2","MMP9",
    "CXCR4",
    "S100A4",
]

# -------------------------
# Optional polygon-level ROI plots (Xenium outputs)
# -------------------------
# If provided, should contain cell_boundaries.parquet or cell_boundaries.csv.gz
XENIUM_OUTDIR = None  # e.g. Path("./xenium_output/<sample>/output-...")

# Loupe Browser "Selection" export (CSV). Usually needs header=2 to skip metadata lines.
ROI_CSV = None        # e.g. Path("./data/Selection_1_cells_stats.csv")

# Exports
OUT_H5AD        = OUT_DIR / "PDAC_Xenium.boundary_emt.h5ad"
OUT_SAMPLE_MAP  = TABLE_DIR / "sample_rename_map.csv"
OUT_CELL_COUNTS = TABLE_DIR / "cell_counts_by_sample_and_type.csv"


In [None]:
## =========================
## Reproducibility: seeds + packages + versions
## =========================
# Seeds the global RNGs, sets scanpy/matplotlib figure defaults and prints
# library versions so results can be traced back to an environment.
import os
import random
import numpy as np
import pandas as pd

import scanpy as sc
import scipy
import sklearn
import matplotlib as mpl
import matplotlib.pyplot as plt
from IPython.display import display

# Seeds
# NumPy's global RNG is what the optional boundary-fit QC plot draws its
# point subsample from (np.random.choice), so seeding it fixes that subsample.
random.seed(SEED)
np.random.seed(SEED)

# Scanpy defaults
# sc.pl.* calls with `save=...` write PDFs into FIG_DIR.
sc.settings.verbosity = 2
sc.settings.figdir = str(FIG_DIR)
sc.settings.file_format_figs = "pdf"
sc.set_figure_params(figsize=(5, 5), dpi=120)

# Matplotlib: keep text editable in PDF (Illustrator-friendly)
# fonttype 42 embeds TrueType fonts, so text stays text in PDF/PS output;
# svg.fonttype "none" writes text as <text> elements instead of paths.
mpl.rcParams["pdf.fonttype"] = 42
mpl.rcParams["ps.fonttype"]  = 42
mpl.rcParams["svg.fonttype"] = "none"

print("scanpy:", sc.__version__)
print("pandas:", pd.__version__)
print("numpy:", np.__version__)
print("scipy:", scipy.__version__)
print("sklearn:", sklearn.__version__)
print("matplotlib:", mpl.__version__)


## Data loading

The input AnnData must provide:

- `adata.obs[SAMPLE_KEY]` (sample ID per cell)
- `adata.obs[CELLTYPE_KEY]` (broad cell type; used to define tumor vs non-tumor)
- `adata.obsm[SPATIAL_KEY]` (2D spatial coordinates)
- `adata.layers[COUNTS_LAYER]` (raw counts; used for EMT program scoring)

The code will rename samples to `PDAC_P1 ... PDAC_Pn` for downstream reporting.


In [None]:
## =========================
## Load AnnData + minimal sanity checks
## =========================
from pathlib import Path

# Fail fast with a readable message before h5py gets involved.
h5ad_path = Path(H5AD_PATH)
if not h5ad_path.exists():
    raise FileNotFoundError(f"Missing input: {H5AD_PATH}")

adata = sc.read_h5ad(H5AD_PATH)
print(adata)

# ---- sanity checks: required obs columns, spatial coords, counts layer ----
missing_obs = [key for key in (SAMPLE_KEY, CELLTYPE_KEY) if key not in adata.obs.columns]
if missing_obs:
    raise KeyError(f"adata.obs['{missing_obs[0]}'] not found. Available: {list(adata.obs.columns)[:30]} ...")

if SPATIAL_KEY not in adata.obsm:
    raise KeyError(f"adata.obsm['{SPATIAL_KEY}'] not found. Available: {list(adata.obsm.keys())}")

if COUNTS_LAYER not in adata.layers:
    raise KeyError(f"adata.layers['{COUNTS_LAYER}'] not found. Available: {list(adata.layers.keys())}")

# ---- rename samples to PDAC_P1..Pn; the raw IDs are kept in obs["sample_raw"] ----
adata.obs["sample_raw"] = adata.obs[SAMPLE_KEY].astype(str).values
raw_samples = sorted(pd.unique(adata.obs["sample_raw"]))

# Standard names follow the sorted raw-ID order; the same order fixes the category order.
std_names = [f"PDAC_P{i}" for i in range(1, len(raw_samples) + 1)]
rename_map = dict(zip(raw_samples, std_names))

adata.obs[SAMPLE_KEY] = pd.Categorical(
    adata.obs["sample_raw"].map(rename_map).astype(str),
    categories=std_names,
    ordered=True,
)

adata.uns["sample_rename_map"] = rename_map

df_map = pd.DataFrame({"sample_raw": raw_samples, "sample_std": std_names})
df_map.to_csv(OUT_SAMPLE_MAP, index=False)
df_map


## Boundary detection & distance-to-boundary

1. **Boundary cells**: tumor cells whose local KNN neighborhood contains non-tumor cells (and vice versa), computed **per sample**.
2. **Distance-to-boundary**: (optional) fit a smooth tumor/non-tumor interface via a KNN classifier over a grid and extract the `p(tumor)=0.5` contour.

The main outputs are stored in:

- `adata.obs['is_boundary']` (bool)
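- `adata.obs['dist_to_boundary']` (non-negative distance to the nearest fitted boundary vertex, in the units of `adata.obsm[SPATIAL_KEY]` — usually µm for Xenium; NaN when no boundary could be fitted for the sample)
- `adata.obs['signed_dist_to_boundary']` (same magnitude; positive for tumor cells, negative for non-tumor cells — the sign follows each cell's label, not which side of the contour it lies on)
- `adata.obs['margin_band']` (ordered categorical bins of the signed distance, `MARGIN_STEP_UM` wide; cells beyond ±`MARGIN_RANGE_UM` or without a distance are NaN)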


In [None]:
## =========================
## Helpers: boundary cells + smooth boundary + margin bands
## =========================
from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier
from scipy.ndimage import gaussian_filter
from scipy.spatial import cKDTree

def mark_boundary_cells_knn(
    adata,
    sample_key=SAMPLE_KEY,
    obs_col=CELLTYPE_KEY,
    tumor_label=TUMOR_LABEL,
    n_neighbors=BOUNDARY_KNN_K,
    out_col="is_boundary",
    spatial_key=SPATIAL_KEY,
):
    """Mark boundary cells based on KNN neighborhood mixing (per sample).

    Boundary definition:
      - Tumor cell whose k nearest neighbors include any non-tumor cell
      - Non-tumor cell whose k nearest neighbors include any tumor cell

    Neighbors are searched **within each sample** to avoid cross-slice
    neighbor leakage.

    Parameters
    ----------
    adata : AnnData
        Modified in place: ``adata.obs[out_col]`` (bool) is written.
    sample_key, obs_col : str
        obs columns holding the sample ID and the cell-type label.
    tumor_label : str
        Value of ``obs_col`` that denotes tumor cells.
    n_neighbors : int
        k, the number of neighbors (excluding the cell itself). Samples with
        fewer than ``k + 1`` cells cannot be queried and stay False.
    out_col : str
        Output obs column.
    spatial_key : str
        obsm key with the cell coordinates.

    Returns
    -------
    The same ``adata`` object (for chaining).

    Raises
    ------
    KeyError
        If ``adata.obsm[spatial_key]`` is missing.
    """
    if spatial_key not in adata.obsm:
        raise KeyError(f"adata.obsm['{spatial_key}'] not found")

    k = int(n_neighbors)
    X = np.asarray(adata.obsm[spatial_key])
    # Convert the label columns to str once, not once per sample.
    sample_ids = adata.obs[sample_key].astype(str).to_numpy()
    tumor_all = adata.obs[obs_col].astype(str).to_numpy() == str(tumor_label)
    boundary = np.zeros(adata.n_obs, dtype=bool)

    for s in pd.unique(sample_ids):
        m = sample_ids == s
        Xs = X[m]
        is_tumor = tumor_all[m]

        # kneighbors needs k neighbors plus the query cell itself.
        if Xs.shape[0] < k + 1:
            continue

        nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm="kd_tree").fit(Xs)
        _, idx = nbrs.kneighbors(Xs)
        idx = _drop_self_neighbors(idx)

        neigh_is_tumor = is_tumor[idx]
        b = (is_tumor & (~neigh_is_tumor).any(axis=1)) | ((~is_tumor) & neigh_is_tumor.any(axis=1))
        boundary[m] = b

    adata.obs[out_col] = boundary
    return adata


def _drop_self_neighbors(idx):
    """Remove each query point's own index from an (n, k+1) kneighbors result.

    With duplicated coordinates the query point is not guaranteed to be in
    column 0 (or even present), so self is located explicitly; when it is
    absent, the farthest (last) neighbor is dropped instead. Returns (n, k).
    """
    n = idx.shape[0]
    is_self = idx == np.arange(n)[:, None]
    is_self[~is_self.any(axis=1), -1] = True
    # Boolean indexing is row-major, so each row keeps its remaining k entries in order.
    return idx[~is_self].reshape(n, idx.shape[1] - 1)


def fit_boundary_and_distance_per_sample(
    adata,
    sample_key=SAMPLE_KEY,
    obs_col=CELLTYPE_KEY,
    tumor_label=TUMOR_LABEL,
    spatial_key=SPATIAL_KEY,
    # classifier & grid
    clf_k=CLF_K,
    grid_resolution=GRID_RESOLUTION,
    padding_percent=PADDING_PERCENT,
    smooth_sigma=SMOOTH_SIGMA,
    max_dist_threshold=MAX_DIST_THRESHOLD,
    min_vertex_count=MIN_VERTEX_COUNT,
    # outputs
    dist_col="dist_to_boundary",
    signed_dist_col="signed_dist_to_boundary",
    store_key="fitted_boundary_coords_by_sample",
    plot=False,
    plot_max_points=80000,
    min_cells=2000,
):
    """Fit a smooth tumor/non-tumor boundary and compute distances (per sample).

    Implementation:
      1) Train KNN classifier: p(tumor | x,y)
      2) Evaluate on a regular grid and smooth (Gaussian)
      3) Mask grid nodes farther than ``max_dist_threshold`` from any cell
      4) Extract contour at p=0.5 as boundary polyline(s); drop segments with
         fewer than ``min_vertex_count`` vertices
      5) Compute each cell's nearest distance to the boundary polyline points
      6) Signed distance: tumor-positive; non-tumor-negative (by cell label)

    Samples with fewer than ``min_cells`` cells, or containing only tumor or
    only non-tumor cells, have no meaningful interface: their distances stay
    NaN and an empty segment list is stored.

    Parameters
    ----------
    adata : AnnData
        Modified in place: writes ``obs[dist_col]``, ``obs[signed_dist_col]``
        and ``uns[store_key]`` (dict: sample -> list of (m, 2) contour arrays).
    plot : bool
        If True, show a per-sample QC scatter with the fitted contour
        (at most ``plot_max_points`` randomly sampled cells).
    min_cells : int
        Minimum cells per sample for a contour fit (default 2000).

    Notes:
      - Units follow `adata.obsm[spatial_key]` (Xenium usually µm); only the
        first two coordinate columns are used.
      - Grid/contour approach is slower but yields a visually smooth boundary.

    Returns
    -------
    The same ``adata`` object.
    """
    if spatial_key not in adata.obsm:
        raise KeyError(f"adata.obsm['{spatial_key}'] not found")

    # The grid and contour are 2-D, so fit/query on (x, y) only.
    X = np.asarray(adata.obsm[spatial_key])[:, :2]
    adata.obs[dist_col] = np.nan
    adata.obs[signed_dist_col] = np.nan

    # Convert label columns once rather than once per sample.
    sample_ids = adata.obs[sample_key].astype(str).to_numpy()
    tumor_all = (adata.obs[obs_col].astype(str).to_numpy() == str(tumor_label)).astype(int)

    boundary_dict = {}

    for s in pd.unique(sample_ids):
        m = sample_ids == s
        Xs = X[m]
        is_tumor = tumor_all[m]

        # Too few cells for a stable fit, or a single class (predict_proba would
        # have one column and there is no interface): keep NaN distances.
        if Xs.shape[0] < int(min_cells) or np.unique(is_tumor).size < 2:
            boundary_dict[s] = []
            continue

        valid = _fit_boundary_segments(
            Xs,
            is_tumor,
            clf_k=clf_k,
            grid_resolution=grid_resolution,
            padding_percent=padding_percent,
            smooth_sigma=smooth_sigma,
            max_dist_threshold=max_dist_threshold,
            min_vertex_count=min_vertex_count,
        )
        boundary_dict[s] = valid

        if valid:
            dist, _ = cKDTree(np.vstack(valid)).query(Xs, k=1)
        else:
            dist = np.full(Xs.shape[0], np.nan)

        adata.obs.loc[m, dist_col] = dist
        adata.obs.loc[m, signed_dist_col] = np.where(is_tumor.astype(bool), dist, -dist)

        if plot:
            _plot_boundary_qc(
                Xs, is_tumor, valid,
                title=f"{s} fitted boundary (Tumor={tumor_label})",
                max_points=plot_max_points,
            )

    adata.uns[store_key] = boundary_dict
    return adata


def _fit_boundary_segments(
    Xs, is_tumor, clf_k, grid_resolution, padding_percent,
    smooth_sigma, max_dist_threshold, min_vertex_count,
):
    """Return the p(tumor)=0.5 contour segments (list of (m, 2) arrays) for one sample."""
    clf = KNeighborsClassifier(n_neighbors=int(clf_k), weights="distance")
    clf.fit(Xs, is_tumor)

    x_min, x_max = float(Xs[:, 0].min()), float(Xs[:, 0].max())
    y_min, y_max = float(Xs[:, 1].min()), float(Xs[:, 1].max())
    x_pad = (x_max - x_min) * float(padding_percent)
    y_pad = (y_max - y_min) * float(padding_percent)

    xx, yy = np.meshgrid(
        np.linspace(x_min - x_pad, x_max + x_pad, int(grid_resolution)),
        np.linspace(y_min - y_pad, y_max + y_pad, int(grid_resolution)),
    )
    grid = np.c_[xx.ravel(), yy.ravel()]

    # Distance from each grid node to its nearest real cell (used to blank empty areas).
    geom = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(Xs)
    d, _ = geom.kneighbors(grid)
    d = d.reshape(xx.shape)

    # Pick the probability column of class 1 explicitly instead of assuming position.
    tumor_col = int(np.flatnonzero(clf.classes_ == 1)[0])
    Z = clf.predict_proba(grid)[:, tumor_col].reshape(xx.shape)
    # Smooth first, then mask, so masked nodes do not leak into the Gaussian kernel.
    Z = gaussian_filter(Z, sigma=float(smooth_sigma))
    Z = np.ma.masked_where(d > float(max_dist_threshold), Z)

    # Use matplotlib's contour tracer on a throwaway, never-drawn figure.
    fig, ax = plt.subplots()
    try:
        cs = ax.contour(xx, yy, Z, levels=[0.5], alpha=0)
        segments = cs.allsegs[0] if len(cs.allsegs) else []
    finally:
        plt.close(fig)

    return [seg for seg in segments if seg.shape[0] >= int(min_vertex_count)]


def _plot_boundary_qc(Xs, is_tumor, segments, title, max_points):
    """Scatter one sample's cells (tumor vs. other) with the fitted contour overlaid."""
    n = Xs.shape[0]
    if n > max_points:
        idx = np.random.choice(n, max_points, replace=False)
    else:
        idx = np.arange(n)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(
        Xs[idx, 0], Xs[idx, 1],
        s=1,
        c=np.where(is_tumor[idx] == 1, "#7F1CD3", "#F0A830"),
        linewidth=0,
        alpha=0.6,
    )
    for seg in segments:
        ax.plot(seg[:, 0], seg[:, 1], color="black", lw=2)

    ax.set_title(title)
    # Xenium often uses image-like coordinates (y increases downward)
    y0, y1 = ax.get_ylim()
    if y0 < y1:
        ax.invert_yaxis()
    plt.show()


def make_margin_bands(
    signed_dist,
    step_um=MARGIN_STEP_UM,
    range_um=MARGIN_RANGE_UM,
):
    """Create categorical boundary bands and a diverging palette centered at 0.

    Parameters
    ----------
    signed_dist : pandas.Series or array-like
        Signed distance to the boundary (tumor-positive). A Series keeps its
        index, so the result can be assigned straight back to ``adata.obs``.
    step_um : float
        Band width.
    range_um : float
        Half-range; bin edges are
        ``np.arange(-range_um, range_um + step_um, step_um)``. Values outside
        the outermost edges, and NaN inputs, get a NaN band.

    Returns
    -------
    band : pandas.Series
        Ordered categorical with labels like ``"(-100.0,-50.0]"``.
    palette : dict
        label -> hex color: negative (non-tumor side) blue, positive (tumor
        side) red, and the band closest to 0 light grey.
    meta : dict
        ``{"bins": edges, "labels": labels, "mid": band centers}``.
    """
    from matplotlib import colormaps
    from matplotlib.colors import TwoSlopeNorm, to_hex

    step_um = float(step_um)
    range_um = float(range_um)

    bins = np.arange(-range_um, range_um + step_um, step_um)
    labels = [f"({bins[i]},{bins[i+1]}]" for i in range(len(bins) - 1)]

    # pd.cut returns a bare Categorical (no `.cat` accessor) for non-Series
    # input; wrapping in a Series makes ndarray/list inputs work too.
    band = pd.cut(
        pd.Series(signed_dist),
        bins=bins,
        labels=labels,
        include_lowest=True,
    )

    # stable order
    band = band.cat.reorder_categories(labels, ordered=True)

    # palette: diverging colormap with vcenter=0
    mid = (bins[:-1] + bins[1:]) / 2
    norm = TwoSlopeNorm(vmin=float(mid.min()), vcenter=0.0, vmax=float(mid.max()))
    # `mpl.cm.get_cmap` was removed in Matplotlib 3.9; use the colormap registry.
    # "RdBu_r" runs blue (low) -> red (high): negative blue, positive red.
    cmap = colormaps["RdBu_r"]
    colors = [to_hex(cmap(norm(m))) for m in mid]

    # force the band closest to 0 to be white-ish
    zero_idx = int(np.argmin(np.abs(mid)))
    colors[zero_idx] = "#f7f7f7"
    palette = dict(zip(labels, colors))

    return band, palette, {"bins": bins, "labels": labels, "mid": mid}


In [None]:
## =========================
## Run boundary detection + distance metrics
## =========================
# Writes obs["is_boundary"], obs["boundary_show"] and (when available)
# obs["dist_to_boundary"], obs["signed_dist_to_boundary"], obs["margin_band"],
# and saves per-sample spatial figures into sc.settings.figdir.
# 1) boundary cells (KNN mixing)
adata = mark_boundary_cells_knn(
    adata,
    sample_key=SAMPLE_KEY,
    obs_col=CELLTYPE_KEY,
    tumor_label=TUMOR_LABEL,
    n_neighbors=BOUNDARY_KNN_K,
    out_col="is_boundary",
    spatial_key=SPATIAL_KEY,
)
print(adata.obs["is_boundary"].value_counts(dropna=False))

# 2) a display-only column: only tumor cells are split into boundary / non-boundary
# (non-tumor cells flagged in is_boundary are still shown as "other")
show_col = "boundary_show"
is_tumor = (adata.obs[CELLTYPE_KEY].astype(str) == str(TUMOR_LABEL))
is_b = adata.obs["is_boundary"].astype(bool)

adata.obs[show_col] = "other"
adata.obs.loc[is_tumor & (~is_b), show_col] = "tumor_nonboundary"
adata.obs.loc[is_tumor & (is_b),  show_col] = "tumor_boundary"
adata.obs[show_col] = pd.Categorical(
    adata.obs[show_col],
    categories=["other", "tumor_nonboundary", "tumor_boundary"],
    ordered=True,
)
boundary_palette = {
    "other": "#BDBDBD",
    "tumor_nonboundary": "#1f77b4",
    "tumor_boundary": "#d62728",
}

# QC plot: boundary cells per sample
# (each per-sample subset is materialized with .copy() rather than plotted as a view)
for s in adata.obs[SAMPLE_KEY].cat.categories:
    ad_sub = adata[adata.obs[SAMPLE_KEY] == s].copy()
    sc.pl.spatial(
        ad_sub,
        color=show_col,
        spot_size=30,
        palette=boundary_palette,
        show=False,
        save=f"_{s}_boundary_show",
    )

# 3) distance-to-boundary (smooth contour fit; optional)
if RUN_SMOOTH_BOUNDARY_FIT:
    adata = fit_boundary_and_distance_per_sample(
        adata,
        sample_key=SAMPLE_KEY,
        obs_col=CELLTYPE_KEY,
        tumor_label=TUMOR_LABEL,
        spatial_key=SPATIAL_KEY,
        clf_k=CLF_K,
        grid_resolution=GRID_RESOLUTION,
        padding_percent=PADDING_PERCENT,
        smooth_sigma=SMOOTH_SIGMA,
        max_dist_threshold=MAX_DIST_THRESHOLD,
        min_vertex_count=MIN_VERTEX_COUNT,
        dist_col="dist_to_boundary",
        signed_dist_col="signed_dist_to_boundary",
        store_key="fitted_boundary_coords_by_sample",
        plot=False,
    )

# 4) margin bands
# Runs whenever the column exists, including when RUN_SMOOTH_BOUNDARY_FIT is
# False but the loaded h5ad already carried signed_dist_to_boundary.
# NOTE(review): confirm that reusing a precomputed column is intended.
if "signed_dist_to_boundary" in adata.obs.columns:
    band, band_palette, band_meta = make_margin_bands(
        signed_dist=adata.obs["signed_dist_to_boundary"].astype(float),
        step_um=MARGIN_STEP_UM,
        range_um=MARGIN_RANGE_UM,
    )
    adata.obs["margin_band"] = band
    # Store plain lists (not ndarrays) so the metadata round-trips through h5ad.
    adata.uns["margin_band_meta"] = {
        "bins": band_meta["bins"].tolist(),
        "labels": list(band_meta["labels"]),
        "mid": band_meta["mid"].tolist(),
        "palette": band_palette,
    }

    for s in adata.obs[SAMPLE_KEY].cat.categories:
        ad_sub = adata[adata.obs[SAMPLE_KEY] == s].copy()
        sc.pl.spatial(
            ad_sub,
            color="margin_band",
            spot_size=30,
            palette=band_palette,
            legend_loc="right margin",
            show=False,
            save=f"_{s}_margin_band",
        )

# 5) save a broad cell-type map per sample (coloring controlled by scanpy)
for s in adata.obs[SAMPLE_KEY].cat.categories:
    ad_sub = adata[adata.obs[SAMPLE_KEY] == s].copy()
    sc.pl.spatial(
        ad_sub,
        color=CELLTYPE_KEY,
        spot_size=30,
        show=False,
        save=f"_{s}_{CELLTYPE_KEY}",
    )

print("Figures saved under:", sc.settings.figdir)


## Tumor EMT balance score (tumor-only)

1. Start from **candidate MES/EPI genes**.
2. Filter by **panel coverage** and minimal **signal** (fraction non-zero + mean counts) within tumor cells.
3. Compute per-cell score:
   - `MES = mean(log1p(norm_counts))` across MES genes
   - `EPI = mean(log1p(norm_counts))` across EPI genes
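   - `BAL = z(MES) - z(EPI)`, with MES and EPI z-scored across all tumor cells when `ZSCORE_WITHIN_TUMOR=True` (recommended; otherwise the raw difference `MES - EPI`)
   - `BALratio = (MES - EPI) / (|MES| + |EPI|)`, computed from the same (z-scored or raw) MES/EPI values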

Outputs stored in:
- `adata.obs['EMTbal_Ductal_MES']`
- `adata.obs['EMTbal_Ductal_EPI']`
- `adata.obs['EMTbal_Ductal_BAL']`
- `adata.obs['EMTbal_Ductal_BALratio']`


In [None]:
## =========================
## Helpers: panel-aware EMT balance score (tumor-only)
## =========================
import scipy.sparse as sp

def mean_log1p_norm_from_counts(
    adata,
    genes,
    subset_mask=None,
    layer=COUNTS_LAYER,
    target_sum=TARGET_SUM,
):
    """Average log1p(library-normalized counts) over a gene set, per cell.

    Each selected cell's counts (from ``adata.layers[layer]``) are scaled so
    that its total over *all* genes equals ``target_sum``. The selected genes
    are then log1p-transformed and averaged.

    Parameters
    ----------
    genes : iterable of str
        Gene names; names missing from ``adata.var_names`` are ignored.
    subset_mask : boolean array-like or None
        Cells to score; ``None`` scores every cell.

    Returns
    -------
    numpy.ndarray of float, one value per selected cell. It is all-NaN when
    none of ``genes`` is on the panel.
    """
    kept = [g for g in genes if g in adata.var_names]
    if not kept:
        n_out = adata.n_obs if subset_mask is None else int(np.sum(subset_mask))
        return np.full(n_out, np.nan, dtype=float)

    rows = np.arange(adata.n_obs) if subset_mask is None else np.where(subset_mask)[0]
    cols = [int(adata.var_names.get_loc(g)) for g in kept]

    counts = adata.layers[layer][rows]
    library_size = np.asarray(counts.sum(axis=1)).ravel().astype(float)
    # The epsilon avoids division by zero; an all-zero cell stays all-zero after scaling.
    size_factor = float(target_sum) / (library_size + 1e-9)
    selected = counts[:, cols]

    if not sp.issparse(selected):
        normed = np.log1p(np.asarray(selected, float) * size_factor[:, None])
        return normed.mean(axis=1).ravel().astype(float)

    # Sparse path: log1p only the stored entries (log1p(0) == 0 for the rest).
    normed = selected.tocsr(copy=True).multiply(size_factor[:, None])
    normed.data = np.log1p(normed.data)
    return np.asarray(normed.mean(axis=1)).ravel().astype(float)


def panel_coverage_report(adata, genes, subset_mask, layer=COUNTS_LAYER):
    """Gene presence + signal summary within a subset (counts-layer based).

    Parameters
    ----------
    adata : AnnData
    genes : iterable of str
        Candidate genes. The output keeps their order, including duplicates.
    subset_mask : boolean array-like, length ``adata.n_obs``
        Cells to summarize (e.g. tumor cells).
    layer : str
        Key in ``adata.layers`` holding raw counts.

    Returns
    -------
    pandas.DataFrame with columns:
      - ``gene``, ``present`` (in ``adata.var_names``)
      - ``frac_nonzero``: fraction of subset cells with count > 0
      - ``mean_counts``: mean count over subset cells
    Both statistics are NaN for absent genes or an empty subset.
    """
    genes = list(genes)
    cell_idx = np.where(subset_mask)[0]
    n = int(cell_idx.size)

    # Unique present genes, in first-seen order.
    present = [g for g in dict.fromkeys(genes) if g in adata.var_names]
    stats = {}
    if present and n:
        gene_idx = [int(adata.var_names.get_loc(g)) for g in present]
        # Take the few candidate columns first, then the subset rows, instead
        # of copying every gene for the subset and slicing one column per gene.
        Xg = adata.layers[layer][:, gene_idx][cell_idx]
        if sp.issparse(Xg):
            # Count values > 0, not stored entries: getnnz() would also count
            # explicitly stored zeros.
            nnz = np.asarray((Xg > 0).sum(axis=0)).ravel()
            means = np.asarray(Xg.mean(axis=0)).ravel()
        else:
            Xg = np.asarray(Xg)
            nnz = (Xg > 0).sum(axis=0)
            means = Xg.mean(axis=0)
        stats = {g: (float(nnz[k]) / float(n), float(means[k])) for k, g in enumerate(present)}

    rows = []
    for g in genes:
        if g not in adata.var_names:
            rows.append({"gene": g, "present": False, "frac_nonzero": np.nan, "mean_counts": np.nan})
            continue
        frac, mean = stats.get(g, (np.nan, np.nan))
        rows.append({"gene": g, "present": True, "frac_nonzero": frac, "mean_counts": mean})

    return pd.DataFrame(rows)


def pick_genes_by_panel_and_signal(
    adata,
    genes,
    subset_mask,
    layer=COUNTS_LAYER,
    min_frac_nonzero=MIN_FRAC_NONZERO,
    min_mean_counts=MIN_MEAN_COUNTS,
):
    """Keep candidate genes that are on the panel and carry enough signal.

    A gene is kept when it is present in ``adata.var_names`` and, within
    ``subset_mask`` cells, both ``frac_nonzero >= min_frac_nonzero`` and
    ``mean_counts >= min_mean_counts`` hold.

    Returns
    -------
    (picked, report)
        ``picked``: kept gene names in candidate order.
        ``report``: the full coverage table for all candidates.
    """
    report = panel_coverage_report(adata, genes, subset_mask=subset_mask, layer=layer)
    on_panel = report.query("present == True")
    passes = (on_panel["frac_nonzero"] >= float(min_frac_nonzero)) & (
        on_panel["mean_counts"] >= float(min_mean_counts)
    )
    return on_panel.loc[passes, "gene"].tolist(), report


def add_emt_balance_tumor_only(
    adata,
    mes_candidates,
    epi_candidates,
    out_prefix="EMTbal_Ductal",
    sample_key=SAMPLE_KEY,
    majortype_key=CELLTYPE_KEY,
    tumor_label=TUMOR_LABEL,
    layer=COUNTS_LAYER,
    target_sum=TARGET_SUM,
    min_genes_each=MIN_GENES_EACH,
    min_frac_nonzero=MIN_FRAC_NONZERO,
    min_mean_counts=MIN_MEAN_COUNTS,
    zscore_within_tumor=ZSCORE_WITHIN_TUMOR,
):
    """Score a tumor-only EMT balance (MES vs. EPI) and write it to ``adata.obs``.

    Steps:
      1. Filter MES / EPI candidates to panel genes with enough signal among
         tumor cells (``pick_genes_by_panel_and_signal``).
      2. Per tumor cell, MES / EPI = mean(log1p(counts scaled to ``target_sum``))
         over the kept genes.
      3. If ``zscore_within_tumor``, z-score MES and EPI across tumor cells.
      4. BAL = MES - EPI; BALratio = (MES - EPI) / (|MES| + |EPI|).

    Writes ``{out_prefix}_MES``, ``_EPI``, ``_BAL`` and ``_BALratio`` to
    ``adata.obs``. Non-tumor cells are NaN.

    Parameters
    ----------
    sample_key : str
        Accepted for signature symmetry with the other helpers; unused here.

    Returns
    -------
    dict
        ``mes_use`` / ``epi_use``: kept genes.
        ``mes_report`` / ``epi_report``: per-candidate coverage tables.

    Raises
    ------
    ValueError
        If there are no tumor cells, or fewer than ``min_genes_each`` genes
        survive filtering for MES or EPI.
    """
    mt = adata.obs[majortype_key].astype(str).values
    is_tumor = (mt == str(tumor_label))
    if is_tumor.sum() == 0:
        raise ValueError(f"No cells with {majortype_key} == '{tumor_label}'")

    mes_use, mes_rep = pick_genes_by_panel_and_signal(
        adata, mes_candidates, subset_mask=is_tumor,
        layer=layer, min_frac_nonzero=min_frac_nonzero, min_mean_counts=min_mean_counts
    )
    epi_use, epi_rep = pick_genes_by_panel_and_signal(
        adata, epi_candidates, subset_mask=is_tumor,
        layer=layer, min_frac_nonzero=min_frac_nonzero, min_mean_counts=min_mean_counts
    )

    print(f"[{out_prefix}] tumor cells: {int(is_tumor.sum())}")
    print(f"[{out_prefix}] MES picked: {len(mes_use)} | {mes_use}")
    print(f"[{out_prefix}] EPI picked: {len(epi_use)} | {epi_use}")

    if len(mes_use) < int(min_genes_each) or len(epi_use) < int(min_genes_each):
        raise ValueError(
            f"Too few usable genes after filtering. MES={len(mes_use)}, EPI={len(epi_use)}.\n"
            f"Suggestion: lower min_frac_nonzero/min_mean_counts or expand candidates."
        )

    mes = mean_log1p_norm_from_counts(adata, mes_use, subset_mask=is_tumor, layer=layer, target_sum=target_sum)
    epi = mean_log1p_norm_from_counts(adata, epi_use, subset_mask=is_tumor, layer=layer, target_sum=target_sum)

    if zscore_within_tumor:
        # The epsilon keeps a constant program from dividing by zero.
        mes = (mes - np.nanmean(mes)) / (np.nanstd(mes) + 1e-9)
        epi = (epi - np.nanmean(epi)) / (np.nanstd(epi) + 1e-9)

    bal = (mes - epi).astype(float)
    bal_ratio = (mes - epi) / (np.abs(mes) + np.abs(epi) + 1e-9)

    # write back to full-length vectors (NaN for non-tumor cells)
    out_mes = np.full(adata.n_obs, np.nan, float)
    out_epi = np.full(adata.n_obs, np.nan, float)
    out_bal = np.full(adata.n_obs, np.nan, float)
    out_bal_ratio = np.full(adata.n_obs, np.nan, float)

    out_mes[is_tumor] = mes
    out_epi[is_tumor] = epi
    out_bal[is_tumor] = bal
    out_bal_ratio[is_tumor] = bal_ratio

    adata.obs[f"{out_prefix}_MES"] = out_mes
    adata.obs[f"{out_prefix}_EPI"] = out_epi
    adata.obs[f"{out_prefix}_BAL"] = out_bal
    adata.obs[f"{out_prefix}_BALratio"] = out_bal_ratio

    v = out_bal[np.isfinite(out_bal)]
    # min()/max() raise on an empty array, e.g. when min_genes_each=0 lets an
    # empty gene set through and every score is NaN.
    if v.size:
        print(f"[{out_prefix}_BAL] n={v.size}, median={np.median(v):.3f}, range=({v.min():.3f},{v.max():.3f})")
    else:
        print(f"[{out_prefix}_BAL] n=0 (no finite scores)")

    return {
        "mes_use": mes_use, "epi_use": epi_use,
        "mes_report": mes_rep, "epi_report": epi_rep,
    }


In [None]:
## =========================
## Run EMT balance score + basic validation plots
## =========================
emt_info = add_emt_balance_tumor_only(
    adata,
    mes_candidates=MES_PDACHQ,
    epi_candidates=EPI_PDACHQ,
    out_prefix="EMTbal_Ductal",
    majortype_key=CELLTYPE_KEY,
    tumor_label=TUMOR_LABEL,
    layer=COUNTS_LAYER,
    target_sum=TARGET_SUM,
    min_genes_each=MIN_GENES_EACH,
    min_frac_nonzero=MIN_FRAC_NONZERO,
    min_mean_counts=MIN_MEAN_COUNTS,
    zscore_within_tumor=ZSCORE_WITHIN_TUMOR,
)

EMT_COL = "EMTbal_Ductal_BAL"

# Summary table (tumor-only)
is_tumor = (adata.obs[CELLTYPE_KEY].astype(str) == str(TUMOR_LABEL))
df_emt = adata.obs.loc[is_tumor, [SAMPLE_KEY, EMT_COL]].dropna()
# observed=False keeps one row per sample category (count 0 if a sample has no scored tumor cells).
df_emt_summary = df_emt.groupby(SAMPLE_KEY, observed=False)[EMT_COL].describe(percentiles=[0.25, 0.5, 0.75])
df_emt_summary.to_csv(TABLE_DIR / "tumor_emt_summary_by_sample.csv")
# A bare expression in the middle of a cell is not rendered; display it explicitly.
display(df_emt_summary)

# Spatial plots per sample (shared color scale: 1st-99th percentile over all tumor cells)
v = df_emt[EMT_COL].to_numpy(dtype=float)
vmin, vmax = (np.nanpercentile(v, [1, 99]) if v.size else (0.0, 1.0))
print("EMT vmin/vmax (1-99%):", vmin, vmax)

for s in adata.obs[SAMPLE_KEY].cat.categories:
    ad_sub = adata[adata.obs[SAMPLE_KEY] == s].copy()
    sc.pl.spatial(
        ad_sub,
        color=EMT_COL,
        spot_size=20,
        cmap="viridis",
        vmin=vmin, vmax=vmax,
        na_color="#d9d9d9",
        show=False,
        save=f"_{s}_{EMT_COL}",
    )

# EMT vs distance-to-boundary (if available)
if "signed_dist_to_boundary" in adata.obs.columns:
    from scipy.stats import spearmanr

    rows = []
    for s in adata.obs[SAMPLE_KEY].cat.categories:
        m = (
            (adata.obs[SAMPLE_KEY] == s)
            & (adata.obs[CELLTYPE_KEY].astype(str) == str(TUMOR_LABEL))
        )
        df = adata.obs.loc[m, ["signed_dist_to_boundary", EMT_COL]].dropna()
        if df.shape[0] < 50:
            continue
        rho, p = spearmanr(df["signed_dist_to_boundary"].values, df[EMT_COL].values)
        rows.append({"sample": str(s), "spearman_rho": float(rho), "p_value": float(p), "n": int(df.shape[0])})

    # Explicit columns so an empty result (every sample < 50 cells) still sorts and writes a header.
    df_corr = pd.DataFrame(rows, columns=["sample", "spearman_rho", "p_value", "n"]).sort_values("sample")
    df_corr.to_csv(TABLE_DIR / "tumor_emt_vs_signed_dist_spearman.csv", index=False)
    display(df_corr)

    # Binned trend plot (mean EMT per margin band)
    if "margin_band" in adata.obs.columns and "margin_band_meta" in adata.uns:
        band_mid = dict(zip(adata.uns["margin_band_meta"]["labels"], adata.uns["margin_band_meta"]["mid"]))
        df = adata.obs.loc[is_tumor, [SAMPLE_KEY, "margin_band", EMT_COL]].dropna().copy()
        df["band_mid"] = df["margin_band"].astype(str).map(band_mid)

        # observed=True: only sample/band combinations that actually occur
        # (no empty per-sample groups from unused categories).
        df_line = (
            df.groupby([SAMPLE_KEY, "band_mid"], as_index=False, observed=True)[EMT_COL]
              .mean()
              .sort_values([SAMPLE_KEY, "band_mid"])
        )
        df_line.to_csv(TABLE_DIR / "tumor_emt_by_margin_band.csv", index=False)

        for s, sub in df_line.groupby(SAMPLE_KEY, sort=False, observed=True):
            fig, ax = plt.subplots(figsize=(4.5, 3.2))
            ax.plot(sub["band_mid"], sub[EMT_COL], marker="o")
            ax.axvline(0.0, linestyle="--", linewidth=1)
            ax.set_title(f"{s} | {EMT_COL} vs boundary distance")
            ax.set_xlabel("signed distance to boundary (µm)")
            ax.set_ylabel(f"mean {EMT_COL}")
            fig.savefig(FIG_DIR / f"{s}_{EMT_COL}_by_margin_band.pdf", bbox_inches="tight")
            plt.close(fig)

print("Figures saved under:", sc.settings.figdir)


## Results & exports

- `PDAC_Xenium.boundary_emt.h5ad` with all computed columns
- Sample/celltype count tables
- Summary tables produced above (EMT distributions, EMT-vs-boundary correlations, etc.)


In [None]:
## =========================
## Export: tables + updated AnnData
## =========================
# Cell counts by sample × cell type
df_counts = (
    adata.obs
        .groupby([SAMPLE_KEY, CELLTYPE_KEY], observed=True)
        .size()
        .unstack(fill_value=0)
)
df_counts.to_csv(OUT_CELL_COUNTS)
display(df_counts.head())

# Boundary counts (if boundary computed)
if "is_boundary" in adata.obs.columns:
    df_b = (
        adata.obs
            .groupby([SAMPLE_KEY, "is_boundary"], observed=True)
            .size()
            .unstack(fill_value=0)
            .rename(columns={False: "non_boundary", True: "boundary"})
    )
    df_b.to_csv(TABLE_DIR / "boundary_cell_counts_by_sample.csv")
    display(df_b)

# Save updated AnnData.
# uns["fitted_boundary_coords_by_sample"] maps each sample to a *list* of
# (n_i, 2) contour arrays with differing n_i. anndata writes lists through
# np.array(...), which fails on such ragged lists. For writing, each sample's
# list becomes a dict {"seg0000": array, "seg0001": ...}; zero-padding keeps
# HDF5's alphabetical key order equal to segment order. The in-memory object
# is restored afterwards so later cells see the original structure.
BOUNDARY_STORE_KEY = "fitted_boundary_coords_by_sample"


def _h5ad_safe_segments(segs):
    """Return {'seg0000': array, ...} for a list/tuple of segments; other values unchanged."""
    if isinstance(segs, (list, tuple)):
        return {f"seg{i:04d}": np.asarray(seg, dtype=float) for i, seg in enumerate(segs)}
    return segs


_boundary_orig = adata.uns.get(BOUNDARY_STORE_KEY)
if isinstance(_boundary_orig, dict):
    adata.uns[BOUNDARY_STORE_KEY] = {
        str(sample): _h5ad_safe_segments(segs) for sample, segs in _boundary_orig.items()
    }
try:
    adata.write_h5ad(OUT_H5AD, compression="gzip")
finally:
    if isinstance(_boundary_orig, dict):
        adata.uns[BOUNDARY_STORE_KEY] = _boundary_orig
print("Saved:", OUT_H5AD)


# Polygon-level ROI plotting (Xenium cell outlines)

This section is only needed for **cell-outline** figures (PDFs whose text stays editable in Adobe Illustrator), built from the Xenium
`cell_boundaries.parquet` / `cell_boundaries.csv.gz` file.

It requires:
- `XENIUM_OUTDIR` pointing to a Xenium `output-.../` folder
- `ROI_CSV` exported from Loupe Browser ("Selection" CSV; typically `header=2`)


In [None]:
## =========================
## Optional: polygon ROI plot utilities (requires Xenium cell_boundaries.*)
## =========================
from pathlib import Path
from matplotlib.collections import PolyCollection
from matplotlib.patches import Rectangle

def read_xenium_cell_boundaries(xenium_outdir: Path) -> pd.DataFrame:
    """Load Xenium per-cell polygon vertices from an output directory.

    ``cell_boundaries.parquet`` is preferred; ``cell_boundaries.csv.gz`` is
    the fallback.

    Parameters
    ----------
    xenium_outdir : path-like
        Xenium ``output-...`` folder.

    Returns
    -------
    pandas.DataFrame
        Contains at least ``cell_id`` (as str), ``vertex_x`` and ``vertex_y``.

    Raises
    ------
    FileNotFoundError
        Neither boundary file exists.
    KeyError
        A required column is missing.
    """
    xenium_outdir = Path(xenium_outdir)
    parq = xenium_outdir / "cell_boundaries.parquet"
    csvg = xenium_outdir / "cell_boundaries.csv.gz"

    if parq.exists():
        bd = pd.read_parquet(parq)
    elif csvg.exists():
        bd = pd.read_csv(csvg)
    else:
        raise FileNotFoundError(
            f"Cannot find cell_boundaries.parquet or cell_boundaries.csv.gz under: {xenium_outdir}"
        )

    needed = {"cell_id", "vertex_x", "vertex_y"}
    if not needed.issubset(set(bd.columns)):
        raise KeyError(f"Boundary table missing columns {needed}. Found: {bd.columns.tolist()}")

    # Some Xenium parquet files store cell_id as raw bytes. astype(str) alone
    # would turn b"abc-1" into the literal "b'abc-1'" and break matching
    # against AnnData cell IDs, so decode bytes first.
    ids = bd["cell_id"]
    if ids.dtype == object and len(ids) and isinstance(ids.iloc[0], (bytes, bytearray)):
        ids = ids.str.decode("utf-8")
    bd["cell_id"] = ids.astype(str)
    return bd


def add_scalebar(ax, length_um=100.0, label=None, loc="lower right", pad_frac=0.05, lw=3.0, fontsize=9):
    """Draw a horizontal scale bar of ``length_um`` data units on ``ax``.

    The bar is inset by ``pad_frac`` of the current axis span. The label
    (default ``"<length> µm"``) is centered just past the bar.
    ``loc="lower left"`` anchors the bar at the left edge; any other value
    anchors it at the right edge.

    NOTE(review): the bar's y position is ``ylim[1] - pad_frac * (ylim[1] - ylim[0])``,
    i.e. near ``ylim[1]``, which matplotlib treats as the *top* of the axes.
    It lands visually at the bottom only if the y-axis is inverted after this
    call. Confirm the call order in the caller.
    """
    (x_lo, x_hi), (y_lo, y_hi) = ax.get_xlim(), ax.get_ylim()
    span_x = x_hi - x_lo
    span_y = y_hi - y_lo

    bar_y = y_hi - pad_frac * span_y
    if loc == "lower left":
        bar_x = x_lo + pad_frac * span_x
    else:
        # "lower right" and any unrecognized loc
        bar_x = x_hi - pad_frac * span_x - length_um

    ax.plot([bar_x, bar_x + length_um], [bar_y, bar_y], color="black", lw=lw, solid_capstyle="butt")
    text = label if label is not None else f"{length_um:g} µm"
    ax.text(bar_x + length_um / 2, bar_y - 0.02 * span_y, text, ha="center", va="top", fontsize=fontsize)


def plot_xenium_roi_polygons_simple(
    adata,
    bd: pd.DataFrame,
    sample: str,
    roi_cell_ids,
    colorby=CELLTYPE_KEY,
    cell_id_key="cell_id",
    sample_key=SAMPLE_KEY,
    vertex_x="vertex_x",
    vertex_y="vertex_y",
    figsize=(5, 5),
    fill_alpha=0.95,
    edgecolor="none",
    linewidth=0.0,
    invert_y=True,
    overlay_boundary=True,
    boundary_store_key="fitted_boundary_coords_by_sample",
    boundary_lw=1.2,
    boundary_color="black",
    rasterize_polygons=True,
    max_vertices_per_cell=None,
    add_scalebar_flag=True,
    scalebar_um=100,
    export_pdf_path=None,
):
    """Polygon-level ROI plot from Xenium `cell_boundaries.*` + AnnData annotations.

    Parameters
    ----------
    adata
        AnnData-like object; ``adata.obs`` and ``adata.uns`` are read.
    bd
        Per-vertex boundary table with a ``cell_id`` column plus the
        ``vertex_x`` / ``vertex_y`` columns (e.g. from
        ``read_xenium_cell_boundaries``).
    sample
        Value of ``adata.obs[sample_key]`` to plot (compared as str).
    roi_cell_ids
        Iterable of cell IDs belonging to the ROI (compared as str).
    colorby
        ``adata.obs`` column used for face colors. If the column is
        categorical and ``adata.uns[f"{colorby}_colors"]`` exists, that palette
        is used; otherwise a ``tab20`` palette over the ROI's labels.
    cell_id_key
        ``adata.obs`` column with cell IDs; if absent, ``obs_names`` are used.
    max_vertices_per_cell
        If set, polygons with more vertices are uniformly downsampled.
    overlay_boundary
        Draw segments from ``adata.uns[boundary_store_key][str(sample)]``
        (each indexed as ``seg[:, 0]`` / ``seg[:, 1]``).
    export_pdf_path
        If given, the figure is saved there as PDF.

    Returns
    -------
    (fig, ax)

    Raises
    ------
    ValueError
        If the sample has no cells, the ROI is empty, no polygons match the
        ROI cells, or all matching polygons have fewer than 3 vertices.
    """
    obs = adata.obs
    roi_set = set(pd.Series(roi_cell_ids).astype(str).tolist())

    # sample subset
    m_s = obs[sample_key].astype(str).values == str(sample)
    obs_s = obs.loc[m_s]
    if obs_s.shape[0] == 0:
        raise ValueError(f"No cells for sample='{sample}' in adata.obs['{sample_key}'].")

    # ROI subset (within sample)
    if cell_id_key in obs_s.columns:
        cid = obs_s[cell_id_key].astype(str).to_numpy()
    else:
        # fallback: assume obs_names are cell_id
        cid = obs_s.index.astype(str).to_numpy()

    m_roi = np.isin(cid, list(roi_set))
    obs_roi = obs_s.loc[m_roi]
    if obs_roi.shape[0] == 0:
        raise ValueError("ROI is empty after intersecting with adata (check cell_id mapping).")
    roi_cid = cid[m_roi]

    # category -> color
    lab = obs_roi[colorby].astype(str).to_numpy()
    cats = pd.unique(lab).tolist()

    palette_key = f"{colorby}_colors"
    if palette_key in adata.uns and isinstance(obs[colorby].dtype, pd.CategoricalDtype):
        # scanpy convention: the palette is aligned to the category order;
        # zip() truncates if the palette is shorter than the category list.
        cats_all = obs[colorby].cat.categories.astype(str).tolist()
        col_map = dict(zip(cats_all, list(adata.uns[palette_key])))
    else:
        # fallback: matplotlib tab20 (pyplot.get_cmap survives the
        # matplotlib 3.9 removal of matplotlib.cm.get_cmap)
        cm = plt.get_cmap("tab20", len(cats))
        col_map = {c: mpl.colors.to_hex(cm(i)) for i, c in enumerate(cats)}

    # cell_id -> label, built once (first occurrence wins) instead of scanning
    # obs_roi for every polygon.
    first = ~pd.Index(roi_cid).duplicated(keep="first")
    label_by_id = dict(zip(roi_cid[first], lab[first]))

    # polygons: keep only boundary vertices for ROI cell IDs
    bd_sub = bd[bd["cell_id"].astype(str).isin(set(label_by_id))]
    if bd_sub.shape[0] == 0:
        raise ValueError("No matching polygons in bd for ROI cells. Check cell_id dtype / mapping.")

    n_keep = None if max_vertices_per_cell is None else int(max_vertices_per_cell)
    polys = []
    facecolors = []
    for cid0, g in bd_sub.groupby("cell_id", sort=False):
        xy = g[[vertex_x, vertex_y]].to_numpy(dtype=float)
        # drop the closing vertex if the ring is explicitly closed
        if xy.shape[0] >= 2 and np.allclose(xy[0], xy[-1]):
            xy = xy[:-1]
        if xy.shape[0] < 3:
            continue
        if n_keep is not None and xy.shape[0] > n_keep:
            # uniform downsample
            idx = np.linspace(0, xy.shape[0] - 1, n_keep).astype(int)
            xy = xy[idx]
        polys.append(xy)
        facecolors.append(col_map.get(label_by_id.get(str(cid0), "unknown"), "#cccccc"))

    if len(polys) == 0:
        raise ValueError("No valid polygons produced (all too small?).")

    fig, ax = plt.subplots(figsize=figsize)
    pc = PolyCollection(
        polys,
        facecolors=facecolors,
        edgecolors=edgecolor,
        linewidths=linewidth,
        alpha=float(fill_alpha),
        rasterized=bool(rasterize_polygons),
    )
    ax.add_collection(pc)

    ax.autoscale()
    ax.set_aspect("equal")
    if invert_y:
        ax.invert_yaxis()
    # Freeze the ROI view. With autoscaling still on, the ax.plot calls below
    # (the fitted boundary spans the whole sample) would zoom the view out.
    ax.set_xlim(ax.get_xlim())
    ax.set_ylim(ax.get_ylim())

    ax.set_title(f"{sample} | ROI n={len(polys)}")
    ax.set_xticks([]); ax.set_yticks([])
    for spn in ax.spines.values():
        spn.set_visible(False)

    # overlay fitted boundary (if available)
    if overlay_boundary:
        bdic = adata.uns.get(boundary_store_key, {})
        segs = bdic.get(str(sample), []) if isinstance(bdic, dict) else []
        for seg in (segs if segs is not None else []):
            seg = np.asarray(seg, dtype=float)
            ax.plot(seg[:, 0], seg[:, 1], color=boundary_color, lw=float(boundary_lw), alpha=1.0)

    if add_scalebar_flag:
        add_scalebar(ax, length_um=float(scalebar_um), label=f"{int(scalebar_um)} µm")

    if export_pdf_path is not None:
        fig.savefig(export_pdf_path, format="pdf", bbox_inches="tight")
        print("Saved:", export_pdf_path)

    return fig, ax


# -------------------------
# Example usage (uncomment)
# -------------------------
# if XENIUM_OUTDIR is not None and ROI_CSV is not None:
#     bd = read_xenium_cell_boundaries(XENIUM_OUTDIR)
#     df_roi = pd.read_csv(ROI_CSV, header=2)  # Loupe selection export
#     roi_ids = df_roi["Cell ID"].astype(str).tolist()
#     sample_name = "PDAC_P1"
#     fig, ax = plot_xenium_roi_polygons_simple(
#         adata=adata, bd=bd, sample=sample_name, roi_cell_ids=roi_ids,
#         colorby=CELLTYPE_KEY,
#         export_pdf_path=FIG_DIR / f"{sample_name}_roi_polygons.pdf",
#     )
#     plt.show()


# Side-by-side ROI plot (centroids only; fast)

Draws two ROIs next to each other from cell centroids (no cell outlines). Supports:
- coloring by an `obs` column (categorical or numeric)
- coloring by a gene (computed from `adata.layers['counts']` as log1p(normalize_total))
- optional overlay of the fitted boundary curve


In [None]:
## =========================
## Optional: centroid ROI plots (fast; no cell outlines)
## =========================
from matplotlib.colors import Normalize
import scipy.sparse as sp

def gene_log1p_norm_from_counts_idx(
    adata,
    idx,
    gene: str,
    layer=COUNTS_LAYER,
    target_sum=TARGET_SUM,
):
    """Return log1p(normalize_total) expression of one gene for selected cells.

    For each selected row of ``adata.layers[layer]``, the counts are scaled so
    that the row's total (summed over all genes in the layer) equals
    ``target_sum``. ``log1p`` is then applied to the requested gene's column.
    Rows with zero total counts yield 0.

    Parameters
    ----------
    adata
        AnnData-like object exposing ``var_names`` and ``layers``.
    idx
        Row selector (integer positions or boolean mask) applied to the layer.
    gene
        Gene name in ``adata.var_names``.
    layer
        Layer holding raw counts (dense ndarray or scipy sparse).
    target_sum
        Library size each cell is scaled to before ``log1p``.

    Returns
    -------
    np.ndarray
        1-D float64 array with one value per selected row.

    Raises
    ------
    KeyError
        If ``gene`` is not in ``adata.var_names`` (or ``layer`` is missing).
    """
    if gene not in adata.var_names:
        raise KeyError(f"Gene '{gene}' not in adata.var_names")
    j = int(adata.var_names.get_loc(gene))

    Xsub = adata.layers[layer][idx]

    lib = np.asarray(Xsub.sum(axis=1)).ravel().astype(float)
    scale = float(target_sum) / (lib + 1e-9)

    col = Xsub[:, j]
    # Densify the single column explicitly: np.asarray() on a scipy sparse
    # matrix wraps it in a 0-d object array instead of extracting the values.
    if sp.issparse(col):
        col = col.toarray()
    col = np.asarray(col, dtype=float).ravel()
    return np.log1p(col * scale)


def overlay_boundary_segments(
    ax,
    adata,
    sample: str,
    boundary_store_key="fitted_boundary_coords_by_sample",
    color="black",
    lw=0.8,
    alpha=1.0,
):
    """Draw one sample's stored boundary polylines onto ``ax``.

    Segments come from ``adata.uns[boundary_store_key][str(sample)]``. Each
    segment is drawn as ``seg[:, 0]`` (x) against ``seg[:, 1]`` (y). Nothing
    is drawn if the store is missing, is not a dict, or holds no (or an empty)
    entry for the sample.
    """
    store = adata.uns.get(boundary_store_key, {})
    if not isinstance(store, dict):
        return
    segments = store.get(str(sample), [])
    if not segments:
        return
    for segment in segments:
        ax.plot(
            segment[:, 0],
            segment[:, 1],
            color=color,
            lw=float(lw),
            alpha=float(alpha),
        )


def plot_two_rois_centroids(
    adata,
    sample1: str,
    roi_cell_ids1,
    sample2: str,
    roi_cell_ids2,
    value_key: str,
    value_kind="auto",  # "gene" | "obs" | "auto"
    sample_key=SAMPLE_KEY,
    cell_id_key="cell_id",
    spatial_key=SPATIAL_KEY,
    # visibility filtering
    show_only_key=None,
    show_only_values=None,
    show_only_mode="transparent",  # "transparent" or "filter"
    # boundary band restriction (uses dist_to_boundary if present)
    restrict_to_boundary_band=False,
    boundary_band_um=100.0,
    boundary_dist_col="dist_to_boundary",
    boundary_dist_use_abs=True,
    # plotting
    figsize=(10, 4.5),
    invert_y=True,
    base_alpha=0.25,
    base_size=2.0,
    vis_size=3.0,
    cmap="viridis",
    robust_low=1,
    robust_high=99,
    overlay_boundary=True,
    export_pdf_path=None,
):
    """Side-by-side centroid scatter of two ROIs, colored by a gene or obs column.

    All ROI cells are drawn as a light-grey base layer. The "visible" subset is
    drawn on top, colored by ``value_key``.

    Parameters
    ----------
    adata
        AnnData-like object; ``obs``, ``obsm[spatial_key]`` and (for genes)
        ``layers[COUNTS_LAYER]`` are read.
    sample1, sample2
        Values of ``adata.obs[sample_key]`` for the left / right panel.
    roi_cell_ids1, roi_cell_ids2
        Cell IDs of each ROI (compared as str against ``obs[cell_id_key]``, or
        against ``obs_names`` if that column is absent).
    value_key
        Gene name or ``adata.obs`` column used for coloring.
    value_kind
        "gene", "obs", or "auto" (a gene if ``value_key`` is in ``var_names``).
    show_only_key, show_only_values
        Restrict colored cells to those whose ``obs[show_only_key]`` matches
        ``show_only_values`` (a scalar or an iterable; compared as str).
    show_only_mode
        "transparent" keeps non-matching ROI cells in the grey base layer;
        "filter" removes them from the panel (and from its extent).
    restrict_to_boundary_band
        If True, color only cells with ``obs[boundary_dist_col]`` (abs if
        ``boundary_dist_use_abs``) ``<= boundary_band_um``.
    robust_low, robust_high
        Percentiles of the pooled numeric values used for the shared color scale.
    overlay_boundary
        Draw fitted boundary segments via ``overlay_boundary_segments``.
    export_pdf_path
        If given, the figure is saved there as PDF.

    Returns
    -------
    (fig, axes)

    Raises
    ------
    ValueError
        For an invalid ``value_kind`` / ``show_only_mode``, an empty ROI, no
        visible cells after filtering, or ``show_only_key`` without
        ``show_only_values``.
    KeyError
        If ``value_key`` / ``boundary_dist_col`` cannot be found.
    """
    if value_kind not in ("gene", "obs", "auto"):
        raise ValueError(f"value_kind must be 'gene', 'obs' or 'auto'; got {value_kind!r}")
    if show_only_mode not in ("transparent", "filter"):
        raise ValueError(f"show_only_mode must be 'transparent' or 'filter'; got {show_only_mode!r}")

    obs = adata.obs
    xy = np.asarray(adata.obsm[spatial_key])

    # Cell-id strings are the same for both ROIs: compute them once.
    if cell_id_key in obs.columns:
        cid_all = obs[cell_id_key].astype(str).to_numpy()
    else:
        cid_all = obs.index.astype(str).to_numpy()

    def _idx(sample, roi_ids):
        """Return (idx_roi, idx_vis, xlim, ylim) for one ROI as positional indices."""
        m_s = obs[sample_key].astype(str).values == str(sample)

        roi_set = set(pd.Series(roi_ids).astype(str).tolist())
        m_roi = np.isin(cid_all, list(roi_set))
        idx_roi = np.where(m_s & m_roi)[0]
        if idx_roi.size == 0:
            raise ValueError(f"Empty ROI for sample={sample}. Check ROI ids + cell_id mapping.")

        # visible subset
        idx_vis = idx_roi
        if show_only_key is not None:
            if show_only_values is None:
                raise ValueError("show_only_values must be provided when show_only_key is set.")
            if np.isscalar(show_only_values):
                wanted = [str(show_only_values)]
            else:
                wanted = [str(v) for v in show_only_values]
            vals = obs.iloc[idx_roi][show_only_key].astype(str).values
            idx_vis = idx_roi[np.isin(vals, wanted)]
            if show_only_mode == "filter":
                idx_roi = idx_vis

        # boundary band restriction (applies to colored cells only)
        if restrict_to_boundary_band:
            if boundary_dist_col not in obs.columns:
                raise KeyError(f"{boundary_dist_col} not in adata.obs (run boundary fitting first)")
            dist = pd.to_numeric(obs.iloc[idx_vis][boundary_dist_col], errors="coerce").to_numpy(dtype=float)
            if boundary_dist_use_abs:
                dist = np.abs(dist)
            keep = np.isfinite(dist) & (dist <= float(boundary_band_um))
            idx_vis = idx_vis[keep]

        if idx_vis.size == 0:
            raise ValueError("No visible points after filtering.")

        # plot extent from ROI
        xy_roi = xy[idx_roi]
        pad = 10.0
        xlim = (float(xy_roi[:, 0].min() - pad), float(xy_roi[:, 0].max() + pad))
        ylim = (float(xy_roi[:, 1].min() - pad), float(xy_roi[:, 1].max() + pad))

        return idx_roi, idx_vis, xlim, ylim

    idx_roi1, idx_vis1, xlim1, ylim1 = _idx(sample1, roi_cell_ids1)
    idx_roi2, idx_vis2, xlim2, ylim2 = _idx(sample2, roi_cell_ids2)

    # value extraction
    def _values(idx_vis):
        """Return (values, "numeric" | "categorical") for the visible cells."""
        if value_kind == "gene" or (value_kind == "auto" and value_key in adata.var_names):
            v = gene_log1p_norm_from_counts_idx(adata, idx_vis, gene=value_key, layer=COUNTS_LAYER, target_sum=TARGET_SUM)
            return v, "numeric"
        if value_key not in obs.columns:
            raise KeyError(f"{value_key} not in adata.obs and not a gene")
        s = obs.iloc[idx_vis][value_key]
        if pd.api.types.is_numeric_dtype(s):
            return pd.to_numeric(s, errors="coerce").to_numpy(dtype=float), "numeric"
        return s.astype(str).to_numpy(), "categorical"

    v1, kind = _values(idx_vis1)
    v2, kind2 = _values(idx_vis2)
    if kind != kind2:
        raise ValueError("value_kind mismatch between ROIs (unexpected).")

    fig, axes = plt.subplots(1, 2, figsize=figsize, constrained_layout=True)
    ax1, ax2 = axes

    # base scatter (all ROI points)
    ax1.scatter(xy[idx_roi1, 0], xy[idx_roi1, 1], s=base_size, c="#c7c7c7", alpha=base_alpha, linewidth=0)
    ax2.scatter(xy[idx_roi2, 0], xy[idx_roi2, 1], s=base_size, c="#c7c7c7", alpha=base_alpha, linewidth=0)

    if kind == "numeric":
        # shared, outlier-robust color scale across both panels
        v_all = np.concatenate([v1[np.isfinite(v1)], v2[np.isfinite(v2)]])
        if v_all.size == 0:
            vmin, vmax = 0.0, 1.0
        else:
            vmin, vmax = np.nanpercentile(v_all, [float(robust_low), float(robust_high)])
        norm = Normalize(vmin=vmin, vmax=vmax)

        ax1.scatter(xy[idx_vis1, 0], xy[idx_vis1, 1], s=vis_size, c=v1, cmap=cmap, norm=norm, linewidth=0)
        s2 = ax2.scatter(xy[idx_vis2, 0], xy[idx_vis2, 1], s=vis_size, c=v2, cmap=cmap, norm=norm, linewidth=0)

        cbar = fig.colorbar(s2, ax=axes, shrink=0.85)
        cbar.set_label(value_key)
    else:
        # one shared category -> color map so both panels agree
        cats = sorted(set(v1.tolist() + v2.tolist()))
        # pyplot.get_cmap survives the matplotlib 3.9 removal of matplotlib.cm.get_cmap
        cm = plt.get_cmap("tab20", len(cats))
        col_map = {c: mpl.colors.to_hex(cm(i)) for i, c in enumerate(cats)}

        for ax, idx_vis, vv in [(ax1, idx_vis1, v1), (ax2, idx_vis2, v2)]:
            for c in cats:
                m = (vv == c)
                if not np.any(m):
                    continue
                ax.scatter(xy[idx_vis[m], 0], xy[idx_vis[m], 1], s=vis_size, c=col_map[c], linewidth=0, label=c)
            ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", frameon=False)

    for ax, (xlim, ylim), sample in [(ax1, (xlim1, ylim1), sample1), (ax2, (xlim2, ylim2), sample2)]:
        ax.set_aspect("equal")
        # explicit limits disable autoscaling, so the boundary overlay below
        # cannot zoom the panel out
        ax.set_xlim(xlim); ax.set_ylim(ylim)
        if invert_y:
            ax.invert_yaxis()
        ax.set_xticks([]); ax.set_yticks([])
        for spn in ax.spines.values():
            spn.set_visible(False)
        ax.set_title(str(sample))
        if overlay_boundary:
            overlay_boundary_segments(ax, adata, sample=sample)

    if export_pdf_path is not None:
        fig.savefig(export_pdf_path, format="pdf", bbox_inches="tight")
        print("Saved:", export_pdf_path)

    return fig, axes


# -------------------------
# Example usage (uncomment)
# -------------------------
# df_roi1 = pd.read_csv(ROI_CSV, header=2)
# df_roi2 = pd.read_csv(ROI_CSV, header=2)
# roi1 = df_roi1["Cell ID"].astype(str).tolist()
# roi2 = df_roi2["Cell ID"].astype(str).tolist()
# fig, axes = plot_two_rois_centroids(
#     adata,
#     sample1="PDAC_P1", roi_cell_ids1=roi1,
#     sample2="PDAC_P1", roi_cell_ids2=roi2,
#     value_key="EMTbal_Ductal_BAL",
#     value_kind="obs",
#     show_only_key=CELLTYPE_KEY, show_only_values=TUMOR_LABEL,
#     restrict_to_boundary_band=True, boundary_band_um=100.0,
#     export_pdf_path=FIG_DIR / "roi1_roi2_emt.pdf",
# )
# plt.show()
