# Spatial transcriptomics deconvolution (cell2location)
Related to Extended Fig. 2.

In [None]:
from pathlib import Path

# ---- Run configuration ------------------------------------------------------
# Seed shared by python / numpy / torch / scvi in the next cell.
SEED = 0

# Input and output roots (relative to the working directory).
DATA_DIR = Path("./data")
OUT_DIR = Path("./outputs/cell2location")
SPATIAL_H5AD = DATA_DIR / "HCC_concat.h5ad"
REF_SIG_H5AD = DATA_DIR.joinpath("reference_signatures", "sc.h5ad")

# AnnData conventions: obs column naming the slide, layer holding raw counts.
BATCH_KEY = "sample"
COUNTS_LAYER = "counts"

# cell2location training and posterior-sampling settings.
C2L_MAX_EPOCHS = 30_000
C2L_TRAIN_SIZE = 1.0
C2L_NUM_SAMPLES = 1_000
C2L_POSTERIOR_BATCH_SIZE = 2_500

# Plotting: top-N most abundant cell types per slide unless a list is given.
N_CELLTYPES_PER_PLOT = 6
CELLTYPES_TO_PLOT = None

# Output sub-directories, all created up front.
REF_OUT_DIR, SP_OUT_DIR, RES_DIR, PLOT_DIR = (
    OUT_DIR / sub for sub in ("reference_signatures", "cell2location_sp", "results", "plots")
)

for out_dir in (OUT_DIR, REF_OUT_DIR, SP_OUT_DIR, RES_DIR, PLOT_DIR):
    out_dir.mkdir(parents=True, exist_ok=True)

print("Inputs:", SPATIAL_H5AD, REF_SIG_H5AD, sep="\n - ")
print("Outputs:", OUT_DIR.resolve(), sep="\n - ")

In [None]:
import os
import random
import numpy as np
import pandas as pd
import scanpy as sc
from importlib.metadata import version, PackageNotFoundError
import sys

try:
    import torch
except Exception:
    torch = None

# Seed every RNG this notebook may touch so reruns are reproducible.
for _seed_fn in (random.seed, np.random.seed):
    _seed_fn(SEED)

if torch:  # torch is None when its import failed above
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Prefer deterministic cuDNN kernels over autotuned ones.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# scvi-tools is optional at this point; fall back to None if unavailable.
try:
    import scvi
    scvi.settings.seed = SEED
except Exception:
    scvi = None

def _v(pkg):
    """Return the installed version string of distribution *pkg*, or "not-installed"."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        pass
    return "not-installed"

print("Versions:")
for pkg in ("python", "numpy", "pandas", "scanpy", "anndata", "scvi-tools", "cell2location", "torch"):
    # "python" is the interpreter itself, not an installed distribution.
    shown = sys.version.split()[0] if pkg == "python" else _v(pkg)
    print(f" - {pkg}: {shown}")

try:
    import cell2location
    from cell2location.models import Cell2location
    from cell2location.utils import select_slide
    from cell2location.plt import plot_spatial
except Exception as e:
    raise ImportError("cell2location is required.") from e

## Data loading

In [None]:
adata_sp = sc.read_h5ad(SPATIAL_H5AD)
adata_ref = sc.read_h5ad(REF_SIG_H5AD)

print("Spatial:", adata_sp.shape)
print("Reference signatures:", adata_ref.shape)

# cell2location is fit on the raw counts in COUNTS_LAYER. An existing layer is
# used as-is (no copy needed); otherwise .X is snapshotted into it.
# NOTE(review): assumes .X holds un-normalized counts when the layer is absent.
if COUNTS_LAYER not in adata_sp.layers:
    adata_sp.layers[COUNTS_LAYER] = adata_sp.X.copy()

# Cheap sanity check on the first spots: non-integer values suggest the matrix
# was normalized/log-transformed, whereas cell2location expects raw counts.
_probe = adata_sp.layers[COUNTS_LAYER][: min(100, adata_sp.n_obs)]
_probe = _probe.toarray() if hasattr(_probe, "toarray") else np.asarray(_probe)
if _probe.size and not np.allclose(_probe, np.round(_probe)):
    print(f"Warning: layer '{COUNTS_LAYER}' contains non-integer values; cell2location expects raw counts.")

if BATCH_KEY and BATCH_KEY not in adata_sp.obs.columns:
    print(f"Warning: BATCH_KEY='{BATCH_KEY}' not found; proceeding without.")
    BATCH_KEY = None

# Tissue images/scale factors are only needed for plotting (show_img=True later).
if "spatial" not in adata_sp.uns:
    print("Warning: adata_sp.uns['spatial'] missing.")

print("Gene overlap:", len(adata_sp.var_names.intersection(adata_ref.var_names)))

## Core analysis

In [None]:
if "mod" not in adata_ref.uns or "factor_names" not in adata_ref.uns["mod"]:
    raise KeyError("REF_SIG_H5AD must contain factor_names.")
factor_names = list(adata_ref.uns["mod"]["factor_names"])

# Per-cluster mean expression columns are named "means_per_cluster_mu_fg_<factor>".
# Select them by exact name, in factor_names order, so each signature carries
# the label of the cell type it belongs to. Substring matching followed by
# positional relabelling can attach the wrong label when the stored column
# order differs from factor_names.
prefixed = [f"means_per_cluster_mu_fg_{ct}" for ct in factor_names]

if "means_per_cluster_mu_fg" in adata_ref.varm:
    mu = adata_ref.varm["means_per_cluster_mu_fg"]
    if isinstance(mu, pd.DataFrame):
        if set(prefixed).issubset(mu.columns):
            inf_aver = mu[prefixed].copy()
        elif set(factor_names).issubset(mu.columns):
            inf_aver = mu[factor_names].copy()
        elif mu.shape[1] == len(factor_names):
            # Unrecognised column labels: assume they are already in factor_names order.
            inf_aver = pd.DataFrame(mu.values, index=adata_ref.var_names, columns=factor_names)
        else:
            raise KeyError(
                "Cannot match varm['means_per_cluster_mu_fg'] columns to factor_names "
                f"({mu.shape[1]} columns vs {len(factor_names)} factors)."
            )
        inf_aver.columns = factor_names
    else:
        # Plain array: columns are taken to be in factor_names order.
        inf_aver = pd.DataFrame(np.asarray(mu), index=adata_ref.var_names, columns=factor_names)
elif set(prefixed).issubset(adata_ref.var.columns):
    inf_aver = adata_ref.var[prefixed].copy()
    inf_aver.columns = factor_names
else:
    raise KeyError("Cannot find reference signatures.")

# Restrict both objects to the shared gene set, in the same gene order.
common_genes = adata_sp.var_names.intersection(inf_aver.index)
if len(common_genes) < 100:
    raise ValueError(f"Too few shared genes: {len(common_genes)}")

adata_sp = adata_sp[:, common_genes].copy()
inf_aver = inf_aver.loc[common_genes].copy()

print("Aligned:", adata_sp.shape, inf_aver.shape)

sig_csv = REF_OUT_DIR / "inferred_cell_state.csv"
inf_aver.to_csv(sig_csv)
print("Saved signatures:", sig_csv)

In [None]:
# Register the raw-count layer and the (optional) batch/slide column with the model.
Cell2location.setup_anndata(adata_sp, batch_key=BATCH_KEY, layer=COUNTS_LAYER)

# Spatial mapping model driven by the reference signatures in inf_aver
# (genes x cell types, aligned to adata_sp.var_names in the previous cell).
# NOTE(review): no expected cells-per-spot prior (N_cells_per_location in the
# cell2location tutorial) is passed, so the library default applies — confirm.
mod = Cell2location(adata_sp, cell_state_df=inf_aver, detection_alpha=20)
# C2L_TRAIN_SIZE = 1.0: presumably every spot is used for training (no held-out split).
mod.train(max_epochs=C2L_MAX_EPOCHS, train_size=C2L_TRAIN_SIZE)

# Export posterior summaries into adata_sp; the results cell reads the q05
# abundances from adata_sp.obsm and factor names from adata_sp.uns["mod"].
# NOTE(review): with use_quantiles=True, check that the installed cell2location
# accepts "num_samples" in sample_kwargs (quantiles may not be sample-based).
adata_sp = mod.export_posterior(
    adata_sp,
    sample_kwargs={"num_samples": C2L_NUM_SAMPLES, "batch_size": C2L_POSTERIOR_BATCH_SIZE},
    use_quantiles=True,
)

# Persist the trained model and the annotated AnnData side by side in SP_OUT_DIR.
mod.save(SP_OUT_DIR.as_posix(), overwrite=True)
sp_h5ad = SP_OUT_DIR / "sp.h5ad"
adata_sp.write(sp_h5ad)
print("Saved spatial model and AnnData:", SP_OUT_DIR)

## Results & exports

In [None]:
if "mod" not in adata_sp.uns or "factor_names" not in adata_sp.uns["mod"]:
    raise KeyError("Expected factor_names in adata_sp.uns['mod'].")

celltypes = list(adata_sp.uns["mod"]["factor_names"])
q05_key = next((k for k in ["q05_cell_abundance_w_sf", "q05_cell_abundance", "q05"] if k in adata_sp.obsm), None)

if q05_key is None:
    raise KeyError("Cannot find q05 abundance.")

# The obsm entry may be a DataFrame whose column labels are not the bare factor
# names (presumably prefixed, e.g. "q05cell_abundance_w_sf_<celltype>").
# pd.DataFrame(df, columns=...) *reindexes* by label instead of renaming, which
# would yield an all-NaN table, so take the raw values and label positionally.
q05_values = adata_sp.obsm[q05_key]
q05_values = q05_values.toarray() if hasattr(q05_values, "toarray") else np.asarray(q05_values)
if q05_values.shape[1] != len(celltypes):
    raise ValueError(
        f"adata_sp.obsm['{q05_key}'] has {q05_values.shape[1]} columns "
        f"but there are {len(celltypes)} factor_names."
    )
q05 = pd.DataFrame(q05_values, index=adata_sp.obs_names, columns=celltypes)
q05.to_csv(RES_DIR / "cell_abundance_q05.csv")

mean_abund = q05.mean(axis=0).sort_values(ascending=False).to_frame("mean_q05")
mean_abund.to_csv(RES_DIR / "cell_abundance_mean_q05.csv")

# Expose each cell type's abundance as an obs column (used by plot_spatial).
for ct in celltypes:
    adata_sp.obs[ct] = q05[ct].values

# dict.fromkeys removes the duplicate when BATCH_KEY is itself "sample".
meta_cols = [c for c in dict.fromkeys([BATCH_KEY, "sample"]) if c and c in adata_sp.obs.columns]
if meta_cols:
    adata_sp.obs[meta_cols].to_csv(RES_DIR / "spot_metadata.csv")

print("Exported abundance and metadata.")

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt

def _save_fig(obj, outpath):
    """Write a figure-like object to *outpath*, then close every open figure.

    *obj* may be a Figure (has ``savefig``) or an object with a ``.figure``
    that has ``savefig``, such as an Axes. Anything else (e.g. None) saves the
    current pyplot figure. Parent directories are created as needed.
    """
    outpath.parent.mkdir(parents=True, exist_ok=True)
    if hasattr(obj, "savefig"):
        target = obj
    elif hasattr(obj, "figure") and hasattr(obj.figure, "savefig"):
        target = obj.figure
    else:
        target = plt
    target.savefig(outpath, bbox_inches="tight")
    plt.close("all")

samples = list(pd.unique(adata_sp.obs[BATCH_KEY])) if BATCH_KEY else [None]

for s in samples:
    # Compare with None explicitly so falsy sample ids (0, "") still select a slide.
    if s is None:
        slide, slide_tag = adata_sp, "all"
    else:
        # NOTE(review): select_slide's third argument is the obs column holding
        # the slide name (documented default "sample"). Pass BATCH_KEY so a
        # non-default batch column still selects the right spots.
        slide, slide_tag = select_slide(adata_sp, s, BATCH_KEY), str(s)

    if CELLTYPES_TO_PLOT:
        ct_plot = [ct for ct in CELLTYPES_TO_PLOT if ct in celltypes]
    else:
        # np.asarray: obsm may hold a DataFrame with prefixed column labels, which
        # pd.DataFrame(..., columns=celltypes) would reindex to all-NaN.
        q05_slide = pd.DataFrame(np.asarray(slide.obsm[q05_key]), index=slide.obs_names, columns=celltypes)
        ct_plot = list(q05_slide.mean(axis=0).sort_values(ascending=False).head(N_CELLTYPES_PER_PLOT).index)

    if not ct_plot:
        print(f"Note: no cell types to plot for slide '{slide_tag}'; skipped.")
        continue

    # Grid of up to 3 panels per row, 6 inches per panel.
    n_cols = min(3, len(ct_plot))
    n_rows = int(np.ceil(len(ct_plot) / 3))
    with mpl.rc_context({"figure.figsize": (6 * n_cols, 6 * n_rows)}):
        fig = plot_spatial(
            adata=slide,
            color=ct_plot,
            labels=ct_plot,
            show_img=True,
            style="fast",
            max_color_quantile=0.992,
            circle_diameter=7,
            colorbar_position="right",
        )
    _save_fig(fig, PLOT_DIR / f"{slide_tag}.cell2location.top{len(ct_plot)}.png")

# Training history plot is best-effort, but report why it was skipped.
try:
    fig_hist = mod.plot_history(20)
    _save_fig(fig_hist, PLOT_DIR / "training_history.png")
except Exception as e:
    plt.close("all")
    print(f"Note: training-history plot skipped ({type(e).__name__}: {e})")

print("Plots saved to:", PLOT_DIR.resolve())

# Downstream1
Related to Extended Fig. 2.

This notebook demonstrates the **shortest reproducible path** from local inputs → spatial niche subtype visualization → subtype-associated composition / marker summaries → saved outputs.


In [None]:
from pathlib import Path
import random
import numpy as np

# ---- Notebook configuration (single representative sample) -----------------
CONFIG = dict(
    sample_id="P9_T",
    data_dir=Path("./data"),
    cellfrac_file="niche/{sample}_cellfrac.h5ad",
    gene_file="niche/{sample}_gene.h5ad",
    subtype_key="subtype",
    seed=0,
    umap_n_neighbors=15,
    umap_min_dist=0.5,
    rank_method="wilcoxon",
    hvg_min_mean=0.0125,
    hvg_max_mean=3.0,
    hvg_min_disp=0.5,
)

# ./outputs/st_subtyping_demo/<sample_id>/{figs,tables}
OUTDIR = Path("./outputs", "st_subtyping_demo", CONFIG["sample_id"])
FIGDIR, TABDIR = OUTDIR / "figs", OUTDIR / "tables"
for _dir in (OUTDIR, FIGDIR, TABDIR):
    _dir.mkdir(parents=True, exist_ok=True)

# Seed python / numpy, and torch when it is installed.
_seed = CONFIG["seed"]
random.seed(_seed)
np.random.seed(_seed)
try:
    import torch
    torch.manual_seed(_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(_seed)
except Exception:
    pass  # torch is optional for this notebook

import pandas as pd
import scanpy as sc
import matplotlib as mpl
import matplotlib.pyplot as plt

sc.settings.verbosity = 2
sc.set_figure_params(dpi=120, facecolor="white")

# Vector-friendly text export: fonttype 42 embeds TrueType fonts in PDF/PS,
# and "none" keeps SVG text as text rather than paths.
mpl.rcParams.update(
    {
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "svg.fonttype": "none",
    }
)

In [None]:
import importlib.metadata as md

# Record the package versions used for this run.
pkgs = ["python", "numpy", "pandas", "scipy", "anndata", "scanpy", "matplotlib", "squidpy"]

for pkg_name in pkgs:
    if pkg_name == "python":
        # The interpreter is not an installed distribution; report sys.version.
        import sys
        print(f"python=={sys.version.split()[0]}")
        continue
    try:
        pkg_version = md.version(pkg_name)
    except md.PackageNotFoundError:
        print(f"{pkg_name} not installed")
    else:
        print(f"{pkg_name}=={pkg_version}")

## Data loading

We load two `AnnData` objects for **one** representative sample:

- `*_cellfrac.h5ad`: spot × cell-type (or feature) matrix (e.g., deconvolution fractions)
- `*_gene.h5ad`: spot × gene expression matrix

**Required fields**
- `adata.obs["subtype"]`: niche subtype label per spot (categorical)
- `adata.obsm["spatial"]`: spatial coordinates for plotting


In [None]:
from pathlib import Path
import re

sample = CONFIG["sample_id"]
subtype_key = CONFIG["subtype_key"]

cellfrac_path = CONFIG["data_dir"] / CONFIG["cellfrac_file"].format(sample=sample)
gene_path     = CONFIG["data_dir"] / CONFIG["gene_file"].format(sample=sample)

# Validate with explicit exceptions: `assert` statements are stripped under `python -O`.
for _path in (cellfrac_path, gene_path):
    if not _path.exists():
        raise FileNotFoundError(f"Missing input: {_path}")

adata_frac = sc.read_h5ad(cellfrac_path)
adata_gene = sc.read_h5ad(gene_path)

for name, ad in [("cellfrac", adata_frac), ("gene", adata_gene)]:
    if subtype_key not in ad.obs.columns:
        raise KeyError(f"[{name}] missing adata.obs['{subtype_key}']")
    if "spatial" not in ad.obsm_keys():
        raise KeyError(f"[{name}] missing adata.obsm['spatial'] (needed for spatial plots)")
print(f"Loaded: adata_frac={adata_frac.shape}, adata_gene={adata_gene.shape}")

# Record the sample id unless the file already carries one (used in plot titles).
adata_frac.uns.setdefault("sample", sample)
adata_gene.uns.setdefault("sample", sample)

CUSTOM_SUBTYPE_ORDER = [
    "Malignant",
    "Non-malignant cell",
    "Malignant EMT/ECM",
    "Malignant EMT/ECM/Prolif",
    "Stromal/CAF",
    "Stromal/Immune",
    "Normal/Immune",
    "Endo/myCAF",
    "Endo",
    "Tumor/Immune",
    "Malignant/Macro",
]

def set_ordered_subtype(ad, key=subtype_key, order=CUSTOM_SUBTYPE_ORDER):
    """Convert ``ad.obs[key]`` to an ordered categorical and return its categories.

    Labels are compared as strings. Entries of *order* that occur in the data
    come first, in that order. Any other labels present in the data are
    appended in sorted order, so no spot is silently turned into NaN (which
    ``pd.Categorical`` does to values missing from its categories). With no
    overlap at all, the categories are simply the sorted labels.
    """
    labels = ad.obs[key].astype(str)
    present = set(labels.unique())
    # dict.fromkeys guards against duplicate entries in a custom order list.
    order_in_use = [x for x in dict.fromkeys(order) if x in present]
    order_in_use += sorted(present.difference(order_in_use))
    ad.obs[key] = pd.Categorical(labels, categories=order_in_use, ordered=True)
    return order_in_use

order_in_use = set_ordered_subtype(adata_frac)
_ = set_ordered_subtype(adata_gene)

# LaTeX-style labels "<GENE>$^{+}$<CELL>", e.g. "FAP$^{+}$ Fibro" -> ("FAP", " Fibro").
latex_pat = re.compile(r"(.+)\$\^\{\+\}\$(.+)")

def sanitize_cellfrac_varnames(ad):
    """Rename "<GENE>$^{+}$<CELL>" features to "<CELL>_<GENE>" in place.

    Non-matching names are kept unchanged. The result is made unique with
    ``ad.var_names_make_unique()``.
    """
    def _rename(name):
        hit = latex_pat.match(name)
        if hit is None:
            return name
        gene, cell = (part.strip() for part in hit.groups())
        return f"{cell}_{gene}"

    ad.var_names = [_rename(n) for n in ad.var_names]
    ad.var_names_make_unique()

sanitize_cellfrac_varnames(adata_frac)

def ensure_umap(ad):
    """Compute ``ad.obsm['X_umap']`` in place unless it already exists.

    Pipeline: PCA -> kNN graph -> UMAP, using the seed and neighbour/min_dist
    settings from CONFIG. The PCA rank is capped at ``min(n_obs, n_vars) - 1``
    (at most 30), because the ARPACK solver requires ``n_comps`` to be strictly
    smaller than both matrix dimensions. Otherwise low-dimensional inputs such
    as cell-fraction matrices with <= 30 features would fail. If fewer than 2
    components are possible, the neighbour graph is built on ``.X`` directly.
    """
    if "X_umap" in ad.obsm_keys():
        return
    n_comps = min(30, ad.n_obs - 1, ad.n_vars - 1)
    if n_comps >= 2:
        sc.tl.pca(ad, n_comps=n_comps, svd_solver="arpack", random_state=CONFIG["seed"])
        sc.pp.neighbors(ad, n_neighbors=CONFIG["umap_n_neighbors"], n_pcs=n_comps)
    else:
        sc.pp.neighbors(ad, n_neighbors=CONFIG["umap_n_neighbors"], use_rep="X")
    sc.tl.umap(ad, min_dist=CONFIG["umap_min_dist"], random_state=CONFIG["seed"])

ensure_umap(adata_frac)
ensure_umap(adata_gene)

## Core analysis

1) Visualize niche subtypes on **UMAP** and **tissue coordinates**.  
2) Summarize subtype-associated **cell-type composition** (from `*_cellfrac.h5ad`).  
3) Compute subtype-associated **marker genes** (from `*_gene.h5ad`) and export a DEG table.


In [None]:
def save_umap(adata, out_path, color=subtype_key):
    """Render a UMAP coloured by *color* (default: the subtype column) and save it to *out_path*."""
    plot_kwargs = dict(color=color, frameon=False, show=False, legend_loc="right margin")
    sc.pl.umap(adata, **plot_kwargs)
    plt.savefig(out_path, bbox_inches="tight")
    plt.close()

def save_spatial(adata, out_path, color=subtype_key, spot_size=7.5):
    """Save a spatial map of ``adata.obs[color]`` to *out_path*.

    First tries ``sc.pl.spatial`` with the lowres tissue image. If that raises,
    the reason is printed and a plain scatter of ``adata.obsm['spatial']`` is
    drawn instead. The scatter uses one colour per category, keeps the
    categorical order when the column is categorical, adds a legend so the
    categories can be identified, and inverts the y-axis as in image coordinates.
    """
    try:
        sc.pl.spatial(
            adata,
            color=color,
            img_key="lowres",
            spot_size=spot_size,
            frameon=False,
            legend_loc="right margin",
            show=False,
        )
        plt.savefig(out_path, bbox_inches="tight")
        plt.close()
        return
    except Exception as err:
        # Drop any half-drawn figure left by the failed call before redrawing.
        plt.close("all")
        print(f"[NOTE] sc.pl.spatial failed ({type(err).__name__}: {err}); drawing scatter fallback.")

    coords = np.asarray(adata.obsm["spatial"])
    labels = adata.obs[color]
    cats = labels.astype(str)
    cat_order = [str(c) for c in labels.cat.categories] if isinstance(labels.dtype, pd.CategoricalDtype) else []
    # Labels outside the declared categories (e.g. "nan") are appended in sorted order.
    cat_order += sorted(set(cats) - set(cat_order))
    palette = plt.get_cmap("tab20")

    fig, ax = plt.subplots(figsize=(6, 6))
    for i, cat in enumerate(cat_order):
        mask = (cats == cat).to_numpy()
        if mask.any():
            ax.scatter(
                coords[mask, 0], coords[mask, 1],
                color=palette(i % palette.N), s=spot_size, linewidths=0, label=cat,
            )
    ax.invert_yaxis()
    ax.axis("off")
    ax.set_title(f"{adata.uns.get('sample', 'sample')} — {color}")
    ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5), frameon=False, fontsize=8, markerscale=2)
    fig.savefig(out_path, bbox_inches="tight")
    plt.close(fig)

save_umap(adata_frac, FIGDIR / f"{sample}_umap_subtype.pdf")
save_spatial(adata_frac, FIGDIR / f"{sample}_spatial_subtype.pdf", spot_size=7.5)

In [None]:
CELLFRAC_FEATURES = [
    "Malignant_C1", "Malignant_C0", "Malignant_C4", "Malignant_C2",
    "Non_malignant",
    "Mono_THBS1",
    "Macro_CXCL9", "Macro_SPP1", "Macro_FOLR2_C1Q",
    "ecmFibro_FAP", "ecmFibro_ASPN",
    "Pericyte",
    "Endo_SEMA3G", "Endo_TFF3",
    "Treg", "CD4T_Naive", "CD8T_Exhausted", "NK_GNLY",
    "cDC1", "cDC2",
]

# Report (rather than silently drop) requested features absent from this sample.
_missing_cellfrac = [x for x in CELLFRAC_FEATURES if x not in adata_frac.var_names]
if _missing_cellfrac:
    print(f"[NOTE] cell-fraction features not in adata_frac.var_names (skipped): {_missing_cellfrac}")
CELLFRAC_FEATURES = [x for x in CELLFRAC_FEATURES if x in adata_frac.var_names]
# Explicit raise: `assert` would be stripped under `python -O`.
if not CELLFRAC_FEATURES:
    raise ValueError("None of the requested cell-fraction features were found in adata_frac.var_names.")

# Mean fraction per subtype, each feature (column) scaled to [0, 1].
sc.pl.matrixplot(
    adata_frac,
    var_names=CELLFRAC_FEATURES,
    groupby=subtype_key,
    dendrogram=False,
    cmap="viridis",
    standard_scale="var",
    colorbar_title="column-scaled\nfraction",
    show=False,
)
plt.savefig(FIGDIR / f"{sample}_matrixplot_cellfrac.pdf", bbox_inches="tight")
plt.close()

def top2_by_groupmean(ad, features, group_key=subtype_key):
    """Return the two groups with the highest z-scored mean for each feature.

    For each feature, the mean value per ``ad.obs[group_key]`` category is
    computed. ``observed=False`` keeps unobserved categories, which get NaN
    means. The means are then z-scored across groups, using pandas' default
    SD with ddof=1.

    Parameters
    ----------
    ad : AnnData-like object with ``.X``, ``.obs``, ``.obs_names`` and ``.var_names``.
    features : iterable of var names to summarise.
    group_key : obs column defining the groups.

    Returns
    -------
    DataFrame indexed by *features* with columns Top1_Subtype, Top1_Z,
    Top2_Subtype, Top2_Z. If fewer than two groups have a finite z-score,
    the missing slots are None / NaN instead of raising IndexError.
    """
    features = list(features)
    # Densify only the requested columns. z-scores are per feature, so the
    # result matches computing on the full matrix.
    X = ad[:, features].X
    # .toarray() works for all scipy sparse types (the ``.A`` shortcut is
    # deprecated/removed in recent SciPy and absent on sparse arrays).
    X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)
    expr = pd.DataFrame(X, index=ad.obs_names, columns=features)

    mean_mat = expr.groupby(ad.obs[group_key], observed=False).mean().T
    z = mean_mat.sub(mean_mat.mean(axis=1), axis=0).div(mean_mat.std(axis=1), axis=0)

    def top_two(row):
        top = row.nlargest(2)  # NaN z-scores are ignored by nlargest
        pad = 2 - len(top)
        names = list(top.index) + [None] * pad
        vals = [float(v) for v in top.to_numpy()] + [np.nan] * pad
        return pd.Series(
            {
                "Top1_Subtype": names[0],
                "Top1_Z": vals[0],
                "Top2_Subtype": names[1],
                "Top2_Z": vals[1],
            }
        )

    return z.loc[features].apply(top_two, axis=1)

cellfrac_top2 = top2_by_groupmean(adata_frac, CELLFRAC_FEATURES)
cellfrac_top2.to_csv(TABDIR / f"{sample}_cellfrac_feature_top2_subtypes.csv")

### Spatial association statistics (global + spatial lag) & HiComb

This section reproduces the spatial-association (co-localization) analysis referred to in the article brief:

- **Global correlation**: spot-wise Spearman correlation between selected deconvolution components.
- **Spatial-lag correlation**: Spearman correlation between **X** at a spot and the **neighbor-averaged Y** (spatial lag).
- **HiComb**: threshold-based “high-combination” states (All3 / pairwise) for the same 3 components, used to summarize co-localization.

> Note: edit the `TARGET3` candidates below if your `adata_frac.var_names` use different names.


In [None]:
# ===== Spatial association statistics (global Spearman + spatial-lag Spearman) & HiComb =====
import numpy as np
import pandas as pd
import scipy.sparse as sp
from scipy.stats import spearmanr

# ---------------------------
# 0) Configuration (edit if needed)
# ---------------------------
# Pick 3 components for colocalization analysis (Malignant / Mono / Fibro as in the original script)
# Candidate var names for each of the three roles. The first name present in
# adata_frac.var_names is used (Malignant / Mono / Fibro, as in the original script).
TARGET3 = dict(
    Malignant=["Malig_EMTinf"],
    Mono=["Mono_THBS1"],
    Fibro=["ecmFibro_FAP"],
)

# Set True when *_cellfrac.h5ad holds absolute abundances. Each spot is then
# divided by its total over all features to obtain fractions.
USE_ROW_NORMALIZED = False

# Spatial neighbourhood graph (Visium default: hex grid); 2 rings -> ~18 neighbours.
COORD_TYPE, N_NEIGHS, N_RINGS = "grid", 6, 2

# Quantile at or above which a component counts as "high" for HiComb.
HICOMB_Q = 0.80

# Permutations for the optional squidpy neighbourhood-enrichment step.
N_PERMS_NHOOD = 1_000


# ---------------------------
# 1) Helper functions
# ---------------------------
def _pick_first_present(cands, available):
    """Return the first element of *cands* found in *available*, or None if none is."""
    return next((cand for cand in cands if cand in available), None)

def _get_vec(ad, var):
    """Return feature *var* of *ad* as a flat 1-D numpy array (sparse X is densified)."""
    col = ad[:, var].X
    dense = col.toarray() if sp.issparse(col) else col
    return np.asarray(dense).ravel()

def get_spatial_W(ad, n_rings=2, n_neighs=6, coord_type="grid", recompute=False):
    """Return a sparse spatial neighbour graph ``W`` (n_obs x n_obs, CSR).

    Resolution order:
      1. ``ad.obsp['spatial_connectivities']`` when present and ``recompute`` is
         False. It is reused as-is, however it was built.
      2. ``squidpy.gr.spatial_neighbors`` with the given coord_type / n_neighs / n_rings.
      3. A kNN fallback on ``ad.obsm['spatial']`` with ``k = 3*r*(r+1)``
         neighbours per spot (the hex-grid count within ``r = n_rings`` rings),
         clamped to ``[1, n_obs - 1]``. ``n_neighs`` and ``coord_type`` are
         ignored here.

    The fallback graph is binary, has no self-loops, is symmetrised with
    ``max(W, W.T)``, and is stored in ``ad.obsp['spatial_connectivities']``.
    """
    key = "spatial_connectivities"
    if (not recompute) and (key in ad.obsp):
        W = ad.obsp[key]
        return W.tocsr() if sp.issparse(W) else sp.csr_matrix(W)

    # 1) squidpy (preferred)
    try:
        import squidpy as sq
        sq.gr.spatial_neighbors(ad, coord_type=coord_type, n_neighs=n_neighs, n_rings=n_rings)
        W = ad.obsp[key]
        return W.tocsr() if sp.issparse(W) else sp.csr_matrix(W)
    except Exception as e:
        print(f"[NOTE] squidpy spatial_neighbors unavailable/failed ({type(e).__name__}); using kNN fallback.")

    # 2) fallback: kNN on the raw coordinates
    coords = np.asarray(ad.obsm["spatial"])
    n = coords.shape[0]
    if n < 2:
        # No possible neighbours: return an empty graph.
        W = sp.csr_matrix((n, n), dtype=float)
        ad.obsp[key] = W
        return W
    k = min(n - 1, max(1, 3 * n_rings * (n_rings + 1)))

    try:
        from sklearn.neighbors import NearestNeighbors
        _, idx = NearestNeighbors(n_neighbors=k + 1).fit(coords).kneighbors(coords)
    except Exception:
        from scipy.spatial import cKDTree
        _, idx = cKDTree(coords).query(coords, k=k + 1)
    idx = np.asarray(idx).reshape(n, k + 1)

    # Each query normally finds itself among its k+1 hits. With duplicated
    # coordinates the tie order is arbitrary, so self is located by value
    # rather than assumed to be hit 0. If self is absent, the farthest hit is
    # dropped instead, so every spot keeps exactly k neighbours.
    drop = idx == np.arange(n)[:, None]
    drop[~drop.any(axis=1), -1] = True
    keep = ~drop
    rows = np.broadcast_to(np.arange(n)[:, None], idx.shape)[keep]
    cols = idx[keep]

    W = sp.csr_matrix((np.ones(cols.size, dtype=float), (rows, cols)), shape=(n, n))
    W = W.maximum(W.T).tocsr()  # make undirected
    ad.obsp[key] = W
    return W

def spatial_lag(v, W):
    """Return the row-normalised spatial lag ``(W @ v) / rowsum(W)``.

    Rows with zero total weight (isolated spots) use a divisor of 1, so their
    lag is 0 rather than NaN.
    """
    row_tot = np.asarray(W.sum(axis=1)).ravel()
    safe_tot = np.where(row_tot == 0, 1.0, row_tot)
    return W.dot(v) / safe_tot

def pairwise_spearman_p(X_df):
    """Return a symmetric DataFrame of pairwise Spearman p-values between the columns of *X_df*.

    Rows and columns are labelled with ``X_df.columns``; the diagonal stays NaN.
    """
    labels = list(X_df.columns)
    P = pd.DataFrame(np.nan, index=labels, columns=labels, dtype=float)
    series = [X_df.iloc[:, pos].values for pos in range(X_df.shape[1])]
    for i, left in enumerate(series):
        for j in range(i + 1, len(series)):
            pval = float(spearmanr(left, series[j]).pvalue)
            P.iat[i, j] = P.iat[j, i] = pval
    return P

def plot_heatmap(mat, labels, title, out_pdf, vmin=-1, vmax=1):
    """Draw square matrix *mat* as an annotated RdBu_r heatmap and save it to *out_pdf*.

    Tick labels come from *labels*. Finite cells are annotated with two
    decimals; non-finite cells (e.g. a NaN diagonal) are left unannotated.
    """
    import matplotlib.pyplot as plt
    ticks = range(len(labels))
    fig, ax = plt.subplots(figsize=(3.8, 3.2))
    im = ax.imshow(mat, vmin=vmin, vmax=vmax, cmap="RdBu_r")
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(labels, rotation=35, ha="right", fontsize=8)
    ax.set_yticklabels(labels, fontsize=8)
    for r, c in ((r, c) for r in ticks for c in ticks):
        cell = mat[r, c]
        if np.isfinite(cell):
            ax.text(c, r, f"{cell:.2f}", ha="center", va="center", fontsize=8)
    fig.colorbar(im, ax=ax, shrink=0.85, label="Spearman ρ")
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(out_pdf, bbox_inches="tight")
    plt.close(fig)


# ---------------------------
# 2) Resolve 3 target features
# ---------------------------
available = list(adata_frac.var_names)
picked = {role: _pick_first_present(cands, available) for role, cands in TARGET3.items()}

if None in picked.values():
    print("[ERROR] Cannot find all 3 TARGET3 features in adata_frac.var_names.")
    for role, cands in TARGET3.items():
        print(f" - {role:9s}: candidates={cands} ; picked={picked[role]}")
    print("\nTip: search your feature names, e.g.\n  [x for x in adata_frac.var_names if 'Mono' in x][:20]")
    raise KeyError("Please edit TARGET3 to match your adata_frac.var_names.")

feat_malig, feat_mono, feat_fibro = (picked[role] for role in ("Malignant", "Mono", "Fibro"))
features3 = [feat_malig, feat_mono, feat_fibro]
print("[coloc3] Using:", features3)


# ---------------------------
# 3) Per-spot vectors of the three features (optionally row-normalised)
# ---------------------------
X3 = pd.DataFrame({feat: _get_vec(adata_frac, feat) for feat in features3}, index=adata_frac.obs_names)

if USE_ROW_NORMALIZED:
    # Per-spot total over *all* features of adata_frac.X; zero totals become 1 to avoid division by zero.
    row_sum = adata_frac.X.sum(axis=1)
    total = row_sum.A1 if sp.issparse(row_sum) else np.asarray(row_sum).ravel()
    total[total == 0] = 1.0
    X3 = X3.div(total, axis=0)
    print("[coloc3] Row-normalized fractions enabled (USE_ROW_NORMALIZED=True).")


# ---------------------------
# 4) Global Spearman correlation (spot-wise)
# ---------------------------
global_rho = X3.corr(method="spearman")
global_p = pairwise_spearman_p(X3)

print("\n[Global correlation] Spearman ρ (spot-wise):")
print(global_rho.round(3))

# All outputs of this analysis share the "<sample>.coloc3" prefix.
out_prefix = f"{sample}.coloc3"
global_rho.to_csv(TABDIR / f"{out_prefix}.global_spearman_rho.csv")
global_p.to_csv(TABDIR / f"{out_prefix}.global_spearman_p.csv")
plot_heatmap(
    global_rho.values, features3, "Global Spearman (spot-wise)",
    FIGDIR / f"{out_prefix}.global_spearman_rho.pdf",
)


# ---------------------------
# 5) Directional spatial-lag Spearman correlation:
#    entry (i, j) = spearman(X_i, lag(X_j)); the diagonal stays NaN
# ---------------------------
W = get_spatial_W(adata_frac, n_rings=N_RINGS, n_neighs=N_NEIGHS, coord_type=COORD_TYPE, recompute=False)
neighbour_mean = {feat: spatial_lag(X3[feat].values, W) for feat in features3}

lag_rho = pd.DataFrame(np.nan, index=features3, columns=features3, dtype=float)
lag_p = lag_rho.copy()

for src in features3:
    for dst in features3:
        if src != dst:
            rho_ij, p_ij = spearmanr(X3[src].values, neighbour_mean[dst])
            lag_rho.loc[src, dst] = float(rho_ij)
            lag_p.loc[src, dst] = float(p_ij)

print("\n[Spatial-lag correlation] Spearman ρ  (i vs neighbors(j)):")
print(lag_rho.round(3))

lag_rho.to_csv(TABDIR / f"{out_prefix}.lag_spearman_rho.csv")
lag_p.to_csv(TABDIR / f"{out_prefix}.lag_spearman_p.csv")
plot_heatmap(
    lag_rho.values, features3, "Lag Spearman (i → neighbors(j))",
    FIGDIR / f"{out_prefix}.lag_spearman_rho.pdf",
)


# ---------------------------
# 6) HiComb (core): define All3 / pairwise high-combination states
# ---------------------------
q = HICOMB_Q
thr = {f: float(np.quantile(X3[f].values, q)) for f in features3}


def _hicomb_is_high(values, threshold):
    """Boolean mask of spots whose value reaches *threshold*.

    Normally ``values >= threshold``. When the quantile collapses onto the
    minimum value (e.g. most spots are exactly 0 for a sparse component),
    ``>=`` would flag every spot as high, so strict ``>`` is used instead.
    """
    if values.size and threshold <= values.min():
        return values > threshold
    return values >= threshold


hi = {f: _hicomb_is_high(X3[f].values, thr[f]) for f in features3}

# All three high -> High_All3; exactly two high -> the matching pairwise label;
# otherwise "Other". The pairwise masks exclude the third component, so the
# order of assignment does not matter.
labels = np.full(adata_frac.n_obs, "Other", dtype=object)
labels[hi[feat_malig] & hi[feat_mono] & hi[feat_fibro]] = "High_All3"
labels[hi[feat_fibro] & hi[feat_malig] & ~hi[feat_mono]] = "High_FAP&Malig"
labels[hi[feat_fibro] & ~hi[feat_malig] & hi[feat_mono]] = "High_FAP&Mono"
labels[~hi[feat_fibro] & hi[feat_malig] & hi[feat_mono]] = "High_Malig&Mono"

hicomb_order = ["High_All3", "High_FAP&Malig", "High_FAP&Mono", "High_Malig&Mono", "Other"]
adata_frac.obs["HiComb"] = pd.Categorical(labels, categories=hicomb_order, ordered=True)

hicomb_counts = adata_frac.obs["HiComb"].value_counts().reindex(hicomb_order, fill_value=0)
print("\n[HiComb] q=%.2f thresholds:" % q, thr)
print("[HiComb] counts:")
print(hicomb_counts.to_string())

# save
pd.Series(thr).to_csv(TABDIR / f"{out_prefix}.HiComb_thresholds_q{int(q*100)}.csv")
hicomb_counts.to_csv(TABDIR / f"{out_prefix}.HiComb_counts.csv")

# Spatial plot; save_spatial falls back to a scatter if the tissue image is unusable.
try:
    save_spatial(adata_frac, FIGDIR / f"{sample}_spatial_HiComb.pdf", color="HiComb", spot_size=7.5)
except Exception as e:
    print("[NOTE] HiComb spatial plot skipped:", type(e).__name__, e)

# Attach lightweight results to AnnData (plain arrays/dicts only, so it can be written to h5ad).
adata_frac.uns["__coloc3__"] = {
    "features3": features3,
    "global_spearman_rho": global_rho.values,
    "global_spearman_p": global_p.values,
    "lag_spearman_rho": lag_rho.values,
    "lag_spearman_p": lag_p.values,
    "hicomb_key": "HiComb",
    "hicomb_q": float(q),
    "hicomb_thresholds": thr,
    "hicomb_counts": hicomb_counts.to_dict(),
    "spatial_graph": {"coord_type": COORD_TYPE, "n_neighs": int(N_NEIGHS), "n_rings": int(N_RINGS)},
}

# ---------------------------
# 7) Optional: HiComb neighborhood enrichment / co-occurrence (requires squidpy)
# ---------------------------
try:
    import squidpy as sq

    sq.gr.nhood_enrichment(adata_frac, cluster_key="HiComb", n_perms=N_PERMS_NHOOD, seed=CONFIG["seed"])
    z = adata_frac.uns["HiComb_nhood_enrichment"]["zscore"]
    z = pd.DataFrame(z, index=hicomb_order, columns=hicomb_order)
    z.to_csv(TABDIR / f"{out_prefix}.HiComb_nhood_enrichment_z.csv")
    print("[HiComb] Saved neighborhood enrichment z-score.")

    sq.gr.co_occurrence(adata_frac, cluster_key="HiComb")
    co = adata_frac.uns["HiComb_co_occurrence"]
    # Store category labels as a unicode array. An object array would be
    # pickled by np.savez, and np.load refuses it under its default
    # allow_pickle=False.
    np.savez(
        TABDIR / f"{out_prefix}.HiComb_co_occurrence.npz",
        occ=np.asarray(co["occ"]),
        interval=np.asarray(co["interval"]),
        categories=np.asarray(hicomb_order, dtype=str),
    )
    print("[HiComb] Saved co-occurrence (npz).")

except Exception as e:
    # Optional step: missing squidpy (or a failure inside it) must not stop the notebook.
    print("[NOTE] squidpy optional steps skipped:", type(e).__name__, e)


In [None]:
# Work on a copy so adata_gene (written out later as *_gene.cleaned.h5ad) keeps its original X.
ad_de = adata_gene.copy()

# Library-size normalisation (target_sum=1e4 per spot), then log1p.
# NOTE(review): assumes adata_gene.X holds raw counts; if it is already
# normalised/log-transformed, this transforms it twice.
sc.pp.normalize_total(ad_de, target_sum=1e4)
sc.pp.log1p(ad_de)
sc.pp.highly_variable_genes(
    ad_de,
    min_mean=CONFIG["hvg_min_mean"],
    max_mean=CONFIG["hvg_max_mean"],
    min_disp=CONFIG["hvg_min_disp"],
)
# Keep the full log-normalised matrix in .raw before restricting X to HVGs.
# NOTE(review): scanpy's rank_genes_groups/dotplot presumably default to .raw
# when it is set, so DE would run on all genes rather than HVGs only — confirm intended.
ad_de.raw = ad_de
ad_de = ad_de[:, ad_de.var["highly_variable"]].copy()

# Marker genes per niche subtype, using the test named in CONFIG["rank_method"].
sc.tl.rank_genes_groups(ad_de, groupby=subtype_key, method=CONFIG["rank_method"])

def _rank_genes_groups_to_df(ad, key="rank_genes_groups"):
    """Flatten ``ad.uns[key]`` (rank_genes_groups output) into a long DataFrame.

    Two layouts of ``ad.uns[key]["names"]`` are handled:
      - a structured/record array with one field per group, or
      - a DataFrame with one column per group.

    Returns one row per (group, gene) with columns group, names, scores,
    logfoldchanges, pvals, pvals_adj. A statistic missing from ``ad.uns[key]``
    is filled with None.
    """
    rg = ad.uns[key]
    names_tbl = rg["names"]
    structured = hasattr(names_tbl, "dtype") and getattr(names_tbl.dtype, "names", None) is not None

    if structured:
        groups = list(names_tbl.dtype.names)

        def column(tbl, grp):
            return tbl[grp]
    else:
        groups = list(names_tbl.columns)

        def column(tbl, grp):
            return tbl[grp].values

    stat_keys = ("scores", "logfoldchanges", "pvals", "pvals_adj")
    records = []
    for grp in groups:
        stats = {s: (None if rg.get(s, None) is None else column(rg[s], grp)) for s in stat_keys}
        for i, gene in enumerate(column(names_tbl, grp)):
            rec = {"group": grp, "names": gene}
            for s in stat_keys:
                rec[s] = None if stats[s] is None else float(stats[s][i])
            records.append(rec)
    return pd.DataFrame(records)

# Prefer scanpy's own long-format export; on any failure, presumably a scanpy
# version lacking group=None support, flatten uns manually.
try:
    deg_df = sc.get.rank_genes_groups_df(ad_de, group=None)
except Exception:
    deg_df = _rank_genes_groups_to_df(ad_de)

deg_csv = TABDIR / f"{sample}_rank_genes_groups_{subtype_key}.csv"
deg_df.to_csv(deg_csv, index=False)

# Top-10 markers per subtype; dot colour is scaled per gene (standard_scale="var").
dotplot_kwargs = dict(groupby=subtype_key, n_genes=10, standard_scale="var", show=False)
sc.pl.rank_genes_groups_dotplot(ad_de, **dotplot_kwargs)
plt.savefig(FIGDIR / f"{sample}_deg_dotplot_top10.pdf", bbox_inches="tight")
plt.close()

In [None]:
# Curated marker panel, restricted to genes present in this sample.
GENE_MARKERS = [
    gene
    for gene in (
        "SAA1", "HSPA5", "GPC3",
        "ALB", "TF", "PCK1",
        "TSPAN8", "CCL20", "CD44", "FN1", "LGALS3", "TIMP1", "AFP", "LTBR", "RELN",
        "VIM", "COL1A1", "FAP",
        "CD74", "HLA-DRB1", "CD52",
        "VWF", "PECAM1",
    )
    if gene in adata_gene.var_names
]

if not GENE_MARKERS:
    print("WARNING: None of the marker genes were found in adata_gene.var_names; skipping matrixplot.")
else:
    # Mean expression per subtype, each gene (column) scaled to [0, 1].
    sc.pl.matrixplot(
        adata_gene,
        var_names=GENE_MARKERS,
        groupby=subtype_key,
        dendrogram=False,
        cmap="viridis",
        standard_scale="var",
        colorbar_title="column-scaled\nexpression",
        show=False,
    )
    plt.savefig(FIGDIR / f"{sample}_matrixplot_gene_markers.pdf", bbox_inches="tight")
    plt.close()

## Results & exports

All outputs are written under:

- `./outputs/st_subtyping_demo/{sample_id}/`

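Key files:
- Figures (PDF): `figs/`
  - `{sample_id}_umap_subtype.pdf`
  - `{sample_id}_spatial_subtype.pdf`
  - `{sample_id}_matrixplot_cellfrac.pdf`
  - `{sample_id}.coloc3.global_spearman_rho.pdf`, `{sample_id}.coloc3.lag_spearman_rho.pdf`
  - `{sample_id}_spatial_HiComb.pdf`
  - `{sample_id}_deg_dotplot_top10.pdf`
  - `{sample_id}_matrixplot_gene_markers.pdf` (if marker genes exist)
- Tables (CSV): `tables/`
  - `{sample_id}_cellfrac_feature_top2_subtypes.csv`
  - `{sample_id}.coloc3.global_spearman_{rho,p}.csv`, `{sample_id}.coloc3.lag_spearman_{rho,p}.csv`
  - `{sample_id}.coloc3.HiComb_thresholds_q80.csv` (suffix follows `HICOMB_Q`), `{sample_id}.coloc3.HiComb_counts.csv`
  - `{sample_id}.coloc3.HiComb_nhood_enrichment_z.csv` and `{sample_id}.coloc3.HiComb_co_occurrence.npz` (only if squidpy is available)
  - `{sample_id}_rank_genes_groups_subtype.csv`
- AnnData (h5ad), in the sample folder itself: `{sample_id}_cellfrac.cleaned.h5ad`, `{sample_id}_gene.cleaned.h5ad`, `{sample_id}_gene.de.h5ad`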


In [None]:
# Persist the annotated objects: cleaned inputs plus the HVG-subset DE object.
for _obj, _suffix in (
    (adata_frac, "cellfrac.cleaned"),
    (adata_gene, "gene.cleaned"),
    (ad_de, "gene.de"),
):
    _obj.write_h5ad(OUTDIR / f"{sample}_{_suffix}.h5ad")

print("Wrote outputs to:")
for _line in (f"  {OUTDIR}", f"  - figs   : {FIGDIR}", f"  - tables : {TABDIR}"):
    print(_line)