# Functions

In [None]:
import matplotlib

# Keep text in exported SVGs as editable <text> elements rather than
# converting glyphs into outlined paths.
matplotlib.rcParams.update({'svg.fonttype': 'none'})

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as mpatches
import pathlib as pl
import scanpy as sc
import pandas as pd
import numpy as np
from tqdm import tqdm

def group_small_clusters(
    df: pd.DataFrame,
    cluster_col: str,
    min_count: int = 1000,
    new_label: str = "small_clusters",
    output_col: "str | None" = None
) -> pd.Series:
    """
    Collapse rare cluster labels into a single catch-all label.

    Parameters:
        df (pd.DataFrame): Input DataFrame containing cluster labels. It is not modified.
        cluster_col (str): Name of the column containing cluster labels (e.g., 'leiden').
        min_count (int): Clusters with fewer than this many rows are grouped.
        new_label (str): Label assigned to every row of a small cluster.
        output_col (str or None): Name given to the returned Series.
                                  If None (or empty), defaults to '{cluster_col}_grouped'.

    Returns:
        pd.Series: String labels aligned to ``df.index`` and named ``output_col``;
        labels of small clusters are replaced by ``new_label``. Missing values are
        not counted as a cluster and are stringified as-is (e.g. NaN -> 'nan').

    Raises:
        ValueError: If ``cluster_col`` is not a column of ``df``.
    """
    if cluster_col not in df.columns:
        raise ValueError(f"Column '{cluster_col}' not found in DataFrame.")

    output_col = output_col or f"{cluster_col}_grouped"
    labels = df[cluster_col]

    cluster_counts = labels.value_counts()
    small_clusters = cluster_counts.index[cluster_counts < min_count]

    # Only this one column is needed, so there is no need to copy the whole frame.
    grouped = labels.astype(str).where(~labels.isin(small_clusters), new_label)
    grouped.name = output_col
    return grouped


In [None]:
def build_palettes_from_adata(adata, palette_specs):
    """
    Build labeled color palettes for categorical columns in adata.obs.

    Labels are the sorted (lexicographic) string forms of each column's
    non-missing values; colors are assigned to labels in that order.

    Parameters
    ----------
    adata : AnnData
        Any object exposing an ``.obs`` DataFrame.
    palette_specs : dict
        Mapping {column_name: palette} where palette can be:
          - a string palette name (e.g. "tab10"), expanded by seaborn to one
            color per label
          - a list/tuple of colors (custom); its first n colors are used

    Returns
    -------
    dict
        {column_name: {label: color}} mapping. Columns absent from
        ``adata.obs`` are skipped with a printed warning.

    Raises
    ------
    ValueError
        If a palette is neither a string nor a list/tuple.
    """
    custom_palettes = {}

    for col, palette in palette_specs.items():
        if col not in adata.obs.columns:
            print(f"⚠️ Warning: '{col}' not found in adata.obs — skipping.")
            continue

        # Drop missing values BEFORE stringifying; otherwise NaN turns into the
        # literal label 'nan' and is given a color of its own.
        unique_vals = sorted(adata.obs[col].dropna().astype(str).unique())
        n_unique = len(unique_vals)

        if isinstance(palette, str):
            # Named palette -> let seaborn generate exactly n colors.
            pal_colors = sns.color_palette(palette, n_colors=n_unique)
        elif isinstance(palette, (list, tuple)):
            pal_colors = list(palette[:n_unique])
            if len(pal_colors) < n_unique:
                # zip() below would silently leave the trailing labels uncolored.
                print(
                    f"⚠️ Warning: palette for '{col}' has {len(pal_colors)} colors "
                    f"for {n_unique} labels; {n_unique - len(pal_colors)} labels get no color."
                )
        else:
            raise ValueError(f"Unsupported palette type for '{col}': {type(palette)}")

        custom_palettes[col] = dict(zip(unique_vals, pal_colors))

    print(f"✅ Built palettes for {len(custom_palettes)} columns.")
    return custom_palettes


def plot_celltype_spatial_single_split_legend(
    df,
    color_by="celltype",
    sample_id=None,
    title=None,
    palette_dict=None,
    palette_name="tab20",
    s=1.5,
    save_svg=True,
    output_prefix="spatial_plot",
    legend_title=None,
    missing_color="lightgray",
):
    """
    Spatial scatterplot of one sample colored by a categorical column.

    The main plot is saved as a PNG (points rasterized); the legend is drawn
    on its own figure and saved separately as an SVG when ``save_svg`` is True.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain ``X_coord``, ``Y_coord`` and ``color_by``, plus
        ``sample_id`` when a ``sample_id`` is given.
    color_by : str
        Column used for the hue.
    sample_id : str, optional
        If given, only rows with this ``sample_id`` are plotted. Also used in the
        output file names ('sample' when None).
    title : str, optional
        Title of the main plot; no title when None.
    palette_dict : dict, optional
        {column: {label: color}}; used when it has an entry for ``color_by``.
    palette_name : str
        Seaborn palette used to generate colors otherwise.
    s : float
        Marker size.
    save_svg : bool
        Whether to save the legend SVG.
    output_prefix : str
        Files are written to '{output_prefix}_{sample}_main.png' and
        '{output_prefix}_{sample}_legend.svg'.
    legend_title : str, optional
        Legend title; defaults to ``color_by``.
    missing_color : color
        Color used for labels absent from the provided palette.

    Raises
    ------
    ValueError
        If ``sample_id`` matches no rows.
    """
    sns.set_style("white")
    sns.set_context("talk")

    # --- Subset one sample ---
    if sample_id is not None:
        df = df[df["sample_id"] == sample_id].copy()
        if df.empty:
            raise ValueError(f"Sample ID '{sample_id}' not found in DataFrame.")

    # --- Colors ---
    unique_labels = sorted(df[color_by].dropna().unique())
    if palette_dict is not None and color_by in palette_dict:
        print('Using provided color palette.')
        # Copy so the caller's palette is never mutated by the fallback below.
        color_dict = dict(palette_dict[color_by])
        missing = [lab for lab in unique_labels if lab not in color_dict]
        if missing:
            # Without a fallback, both seaborn and the legend fail on these labels.
            print(f"Warning: no color for {missing}; using {missing_color}.")
            color_dict.update({lab: missing_color for lab in missing})
    else:
        print('Generating color palette.')
        palette = sns.color_palette(palette_name, n_colors=len(unique_labels))
        color_dict = dict(zip(unique_labels, palette))

    # --- Main plot ---
    fig, ax = plt.subplots(figsize=(6, 5), dpi=300)
    sns.scatterplot(
        data=df,
        x="X_coord", y="Y_coord",
        hue=color_by, palette=color_dict,
        s=s, alpha=0.9, linewidth=0,
        rasterized=True, ax=ax, legend=False
    )
    # Image-style orientation: y grows downwards.
    ax.invert_yaxis()
    ax.set_aspect("equal", adjustable="box")
    for spine in ["top", "right", "left", "bottom"]:
        ax.spines[spine].set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("")
    ax.set_ylabel("")
    if title is not None:
        ax.set_title(title)
    fig.tight_layout()

    # --- Save main figure ---
    fname_main = f"{output_prefix}_{sample_id or 'sample'}_main.png"
    fig.savefig(fname_main, dpi=300, bbox_inches="tight", transparent=True, format="png")
    print(f"Saved main figure: {fname_main}")

    # --- Legend (separate figure; at least one row tall so the size is never 0) ---
    fig_leg, ax_leg = plt.subplots(figsize=(3, 0.5 * max(1, len(unique_labels))), dpi=300)
    handles = [
        plt.Line2D([0], [0], marker='o', color='none', label=label,
                   markerfacecolor=color_dict[label], markersize=8)
        for label in unique_labels
    ]
    ax_leg.legend(handles=handles, loc="center left", frameon=False,
                  title=legend_title or color_by, title_fontsize=14, fontsize=14)
    ax_leg.axis("off")
    fig_leg.tight_layout()

    if save_svg:
        fname_leg = f"{output_prefix}_{sample_id or 'sample'}_legend.svg"
        fig_leg.savefig(fname_leg, dpi=300, bbox_inches="tight", transparent=True, format="svg")
        print(f"Saved legend: {fname_leg}")

    plt.close(fig)
    plt.close(fig_leg)


# Data

In [None]:
import pathlib as pl
import pandas as pd
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

In [None]:
# Sample identifiers; each one has its own sub-folder under base_dir.
sample_list = [
    'P1CRC',
    'P2CRC',
    'P5CRC',
    'P3NAT',
    'P5NAT',
]
base_dir = pl.Path('../../../Broad_SpatialFoundation/VisiumHD-CRC/')

In [None]:
# Joint SpatialFusion embeddings for all samples, read from the same data root
# as the per-sample AnnData files. The loading loop matches this index against
# '<obs_name>::<sample>' names, so the index presumably uses that form.
all_embeddings = pd.read_parquet(base_dir / 'SpatialFusion_full_emb.parquet')

In [None]:
# Load every sample and give each cell a globally unique '<obs_name>::<sample>'
# name. Keep only cells that also have a SpatialFusion embedding, and store the
# embedding columns '0'..'9' in obsm['NicheFinder'].
embedding_cols = [str(i) for i in range(10)]
adatas = []
for sample in tqdm(sample_list):
    adata = sc.read_h5ad(base_dir / sample / 'adata.h5ad')
    adata.obs_names = adata.obs_names + '::' + sample
    adata.obs['sample_id'] = sample

    # One intersection is enough to align the AnnData with its embeddings.
    common_idx = adata.obs_names.intersection(all_embeddings.index)
    adata = adata[common_idx].copy()
    embeddings_df = all_embeddings.loc[common_idx]
    print(adata.shape, embeddings_df.shape)
    adata.obsm['NicheFinder'] = embeddings_df.loc[:, embedding_cols]

    adatas.append(adata)

adata = adatas[0].concatenate(*adatas[1:])
# concatenate() appends '-<batch index>' to every obs name. Strip only that
# final suffix so that any '-' inside the original names survives.
adata.obs_names = adata.obs_names.str.rsplit('-', n=1).str[0]

In [None]:
# Curated per-cell annotation table; its first column holds the cell names.
adata_obs = pd.read_csv(
    'full_CRC_obs.csv',
    index_col=0,
)

In [None]:
# Keep (and order) only the cells listed in the annotation table.
adata = adata[adata_obs.index, :].copy()

In [None]:
# Replace obs wholesale with the curated annotations, row-aligned to the cells.
adata.obs = adata_obs.loc[adata.obs_names, :]

In [None]:
# Discard cells annotated as 'Noise'.
adata = adata[adata.obs['refined_celltypes'] != 'Noise', :].copy()

In [None]:
# Names of the remaining cells that also have a row in the embedding table.
common_idx = adata.obs_names.intersection(all_embeddings.index)

In [None]:
# Restrict the AnnData to cells that have an embedding.
adata = adata[common_idx, :].copy()

In [None]:
# Put the embedding table in exactly the same row order as adata.
all_embeddings = all_embeddings.loc[common_idx, :]

In [None]:
# Copy refined subtype labels onto the embedding table. `.loc[...]` already
# orders rows like all_embeddings, so a positional assignment is safe.
# `.values` replaces Series.ravel(), which is deprecated since pandas 2.2.
all_embeddings['cellsubtypes'] = adata.obs.loc[all_embeddings.index, 'refined_cellsubtypes'].values

In [None]:
# Copy refined cell-type labels onto the embedding table. `.loc[...]` already
# orders rows like all_embeddings, so a positional assignment is safe.
# `.values` replaces Series.ravel(), which is deprecated since pandas 2.2.
all_embeddings['celltypes'] = adata.obs.loc[all_embeddings.index, 'refined_celltypes'].values

In [None]:
# Niche labels: Leiden clusters with fewer than 500 cells become "Other".
# The returned Series is indexed like adata.obs, and the assignment aligns on index.
all_embeddings['leiden_joint'] = group_small_clusters(
    adata.obs[['leiden']], cluster_col='leiden', min_count=500, new_label="Other",
)

In [None]:
# Same niche labels stored on the AnnData itself (clusters < 500 cells -> "Other").
adata.obs['leiden_joint'] = group_small_clusters(
    adata.obs[['leiden']], cluster_col='leiden', min_count=500, new_label="Other",
)

In [None]:
# Store the raw Leiden labels as string categories.
adata.obs['leiden'] = adata.obs['leiden'].astype(str).astype('category')

In [None]:
# Qualitative palettes with a few hues removed. Assuming seaborn's current
# palette is matplotlib's default tab10 cycle (depends on global theme state),
# positions 4 and 6 are purple and pink. In tab20, 8/9 and 12/13 are the
# purple and pink pairs.
tab_filtered = [
    color
    for idx, color in enumerate(sns.color_palette())
    if idx not in {4, 6}
]

tab20_filtered = [
    color
    for idx, color in enumerate(sns.color_palette('tab20') + sns.color_palette('tab20c')[:17])
    if idx not in {8, 9, 12, 13}
]

In [None]:
# Palette for each categorical obs column. Fine-grained columns get the larger
# tab20-based list; the coarse cell types get the tab10-based one.
palette_specs = dict(
    leiden_joint=tab20_filtered,
    cellsubtypes=tab20_filtered,
    celltypes=tab_filtered,
)

palette_dict_1 = build_palettes_from_adata(adata, palette_specs)

In [None]:
# Flat coordinate columns (first and second column of obsm['spatial']) for the
# seaborn scatterplots.
adata.obs['X_coord'], adata.obs['Y_coord'] = adata.obsm['spatial'][:, 0], adata.obsm['spatial'][:, 1]

In [None]:
# One spatial plot + legend per sample, colored by cell type.
for sample in tqdm(adata.obs.sample_id.unique()):
    print(sample)
    # Only obs columns are plotted. Slicing obs avoids copying the whole
    # AnnData (expression matrix included) once per sample.
    sample_obs = adata.obs[adata.obs['sample_id'] == sample]
    plot_celltype_spatial_single_split_legend(
        sample_obs,
        color_by="celltypes",
        sample_id=sample,
        title=None,
        palette_dict=palette_dict_1,
        s=1.5,
        save_svg=True,
        output_prefix="../../../SpatialFusion/results/figures_Fig5/celltypes",
        legend_title=None,
    )

In [None]:
# One spatial plot + legend per sample, colored by cell subtype.
for sample in tqdm(adata.obs.sample_id.unique()):
    print(sample)
    # Only obs columns are plotted. Slicing obs avoids copying the whole
    # AnnData (expression matrix included) once per sample.
    sample_obs = adata.obs[adata.obs['sample_id'] == sample]
    plot_celltype_spatial_single_split_legend(
        sample_obs,
        color_by="cellsubtypes",
        sample_id=sample,
        title=None,
        palette_dict=palette_dict_1,
        s=1,
        save_svg=True,
        output_prefix="../../../SpatialFusion/results/figures_Fig5/cellsubtypes",
        legend_title=None,
    )

In [None]:
# One spatial plot + legend per sample, colored by niche label.
for sample in tqdm(adata.obs.sample_id.unique()):
    print(sample)
    # Only obs columns are plotted. Slicing obs avoids copying the whole
    # AnnData (expression matrix included) once per sample.
    sample_obs = adata.obs[adata.obs['sample_id'] == sample]
    plot_celltype_spatial_single_split_legend(
        sample_obs,
        color_by="leiden_joint",
        sample_id=sample,
        title=None,
        palette_dict=palette_dict_1,
        s=1.5,
        save_svg=True,
        output_prefix="../../../SpatialFusion/results/figures_Fig5/niches",
        legend_title='Niche label',
    )

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap


def _format_pct(val):
    """Return the heatmap annotation text for one percentage value."""
    if pd.isna(val):
        return "N.A."
    if val < 1:
        return "<1%"
    return f"{val:.0f}%"  # round to nearest percent


# --- Data prep ---
# Numeric niche labels in numeric order; the "Other"/"nan" catch-alls are left out.
cl_order = np.sort(np.setdiff1d(all_embeddings.leiden_joint.unique(), ['Other', 'nan']).astype(int)).astype(str)
sample_order = ['P1CRC','P2CRC','P5CRC','P3NAT','P5NAT']

# Rows = niches, columns = samples. Each column is divided by the sample's
# total (catch-all niches included) before those rows are dropped, so the values
# are each niche's share of the sample's cells. Niche/sample pairs with no
# cells stay NaN and are annotated "N.A.".
vc = all_embeddings[['sample_id','leiden_joint']].value_counts().unstack().T
vc = vc / vc.sum(axis=0)
vc = vc.loc[cl_order, sample_order]

# --- Style setup ---
sns.set_theme(context="talk", style="white")

# --- Colormap: near-white -> light red -> deep red ---
cmap = LinearSegmentedColormap.from_list(
    "vlag_redgray",
    ["#f7f7f7", "#f4a3a8", "#b40426"]
)

# --- Annotation matrix (percent strings) ---
annot = vc * 100  # convert to percent
# Map into a new frame of strings. Assigning strings into the float frame
# in place would rely on silent dtype upcasting, deprecated since pandas 2.1.
annot_fmt = annot.apply(lambda col: col.map(_format_pct))

# --- Plot ---
fig, ax = plt.subplots(figsize=(2, 3.5), dpi=300)

sns.heatmap(
    vc,
    cmap=cmap,
    annot=annot_fmt,
    fmt="",
    linewidths=0.4,
    linecolor="white",
    cbar=False,
    annot_kws={"fontsize": 7, "color": "black"},
    ax=ax,
)

# --- Aesthetic adjustments ---
ax.set_xlabel("", fontsize=11)
ax.set_ylabel("", fontsize=11)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=7)
ax.set_yticks(np.arange(len(vc.index)) + 0.5)
ax.set_yticklabels(vc.index, rotation=0, fontsize=7)
ax.tick_params(length=0)

# Remove borders
for spine in ax.spines.values():
    spine.set_visible(False)

# Optional title
ax.set_title("", fontsize=12, pad=10, fontweight="normal")

fig.tight_layout()
fig.savefig('../../../SpatialFusion/results/figures_Fig5/niches_proportions.svg')
plt.show()


# Plot cell type composition

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, Sequence, Mapping, Tuple, Dict, List
from matplotlib.colors import to_hex, to_rgb
import colorsys


def _collapse_rare_types(ct, top_types, min_frac, other_label):
    """Merge type columns that fail the top-N / min-frac filters into `other_label`."""
    if top_types is None and min_frac is None:
        return ct

    totals = ct.sum(axis=0)
    keep_cols = []
    if top_types is not None:
        keep_cols = totals.sort_values(ascending=False).head(top_types).index.tolist()
    if min_frac is not None:
        global_frac = totals / totals.sum()
        frequent = global_frac[global_frac >= min_frac].index.tolist()
        # With both filters a type kept by either survives. With min_frac alone,
        # exactly the types meeting the threshold are kept.
        keep_cols = sorted(set(keep_cols) | set(frequent))

    other = ct.drop(columns=keep_cols, errors="ignore").sum(axis=1)
    collapsed = ct[keep_cols].copy()
    if (other > 0).any():
        if other_label in collapsed.columns:
            # A real type already uses this label: add to it instead of overwriting it.
            collapsed[other_label] = collapsed[other_label] + other
        else:
            collapsed[other_label] = other

    # Keep `other_label` as the last stack segment.
    ordered = [c for c in collapsed.columns if c != other_label]
    if other_label in collapsed.columns:
        ordered.append(other_label)
    return collapsed[ordered]


def _order_clusters(props, cluster_order, strict_order):
    """Reorder the rows of `props` by `cluster_order`, skipping clusters absent from the data."""
    if cluster_order is None:
        return props.sort_index()

    cluster_order = [str(c) for c in cluster_order]
    present = [c for c in cluster_order if c in props.index]
    missing = [c for c in cluster_order if c not in props.index]
    if missing:
        print(f"Warning: these clusters from cluster_order were not found and will be skipped: {missing}")
    if strict_order:
        return props.loc[present]
    extras = [c for c in props.index if c not in cluster_order]
    return props.loc[present + extras]


def plot_cluster_composition_stacked(
    df,
    cluster_key: str = "leiden",
    type_key: str = "celltypes",                 # or "cellsubtypes"
    cluster_order: Optional[Sequence[str]] = None,
    strict_order: bool = False,                  # if True, only show clusters in cluster_order
    palette: Optional[Mapping[str, str]] = None, # dict {celltype: color}; auto if None
    top_types: Optional[int] = None,             # keep top N types globally, rest→"Other"
    min_frac: Optional[float] = None,            # keep types with global frac >= min_frac, rest→"Other"
    other_label: str = "Other",
    type_order: Optional[Sequence[str]] = None,  # custom order of stack segments
    figsize=(10, 5),
    percent_labels: bool = False,                # print % on bars
    label_threshold: float = 0.05,               # only label segments >=5%
    savefig: Optional[str] = None,
):
    """
    Plot a 100% stacked barplot of type proportions per cluster.

    Parameters
    ----------
    df : pd.DataFrame
        Table (e.g. ``adata.obs``) with ``cluster_key`` and ``type_key`` columns.
        Rows missing either value are dropped; labels are compared as strings.
    cluster_key, type_key : str
        Columns giving the bar (cluster) and the stack segment (type).
    cluster_order : sequence, optional
        Preferred cluster order. Clusters absent from the data are skipped
        with a printed warning.
    strict_order : bool
        If True only clusters listed in ``cluster_order`` are shown; otherwise the
        remaining clusters follow in index order.
    palette : mapping, optional
        {type: color}; types not in it get '#BBBBBB'. When None, colors are
        generated (tab10 for up to 10 types, hls beyond) and ``other_label`` is gray.
    top_types : int, optional
        Keep the N most frequent types globally; the rest are summed into ``other_label``.
    min_frac : float, optional
        Keep types whose global fraction is >= ``min_frac``; the rest go to
        ``other_label``. Combined with ``top_types``, a type kept by either is kept.
    other_label : str
        Label of the merged rare-type segment.
    type_order : sequence, optional
        Stack order (bottom to top); unlisted types follow.
    figsize : tuple
        Figure size.
    percent_labels : bool
        Write the percentage on segments of at least ``label_threshold``.
    label_threshold : float
        Minimal fraction for a segment label.
    savefig : str, optional
        Output path (dpi=200).

    Returns
    -------
    pd.DataFrame
        Long-form table with columns [cluster, type, frac, percent, count].

    Raises
    ------
    ValueError
        If no cluster (or no type) is left to plot.
    """
    from matplotlib.ticker import PercentFormatter

    obs = df[[cluster_key, type_key]].dropna().copy()
    obs[cluster_key] = obs[cluster_key].astype(str)
    obs[type_key] = obs[type_key].astype(str)

    # Counts: rows = clusters, columns = types; optionally merge rare types.
    ct = pd.crosstab(obs[cluster_key], obs[type_key])
    ct = _collapse_rare_types(ct, top_types, min_frac, other_label)

    # Row-normalise to fractions (clusters with 0 cells become all-zero rows).
    row_sums = ct.sum(axis=1).replace(0, np.nan)
    props = ct.div(row_sums, axis=0).fillna(0.0)
    # Explicit index name: the long-form conversion below melts on it.
    props = _order_clusters(props, cluster_order, strict_order).rename_axis(index=cluster_key)

    if props.empty:
        raise ValueError("No clusters to plot after filtering/reordering.")

    # Stack (type) order
    types_order = props.columns.tolist()
    if type_order is not None:
        type_order = [t for t in type_order if t in props.columns]
        leftovers = [t for t in props.columns if t not in type_order]
        types_order = type_order + leftovers

    # Color map
    if palette is None:
        n_types = len(types_order)
        # tab10 would repeat colors beyond 10 types.
        base = sns.color_palette("tab10" if n_types <= 10 else "hls", n_colors=n_types)
        colmap = dict(zip(types_order, base))
        if other_label in colmap:
            colmap[other_label] = "#B0B0B0"  # gray for "Other"
    else:
        colmap = {t: palette.get(t, "#BBBBBB") for t in types_order}

    # Stacked bars
    fig, ax = plt.subplots(figsize=figsize)
    x = np.arange(len(props.index))
    bottom = np.zeros(len(props))
    for t in types_order:
        vals = props[t].to_numpy()
        ax.bar(x, vals, bottom=bottom, width=0.9, color=colmap[t], label=t, edgecolor="none")
        bottom += vals

    ax.set_xticks(x)
    ax.set_xticklabels(props.index, rotation=45, ha="right")
    ax.set_ylim(0, 1)
    # Ticks in percent, matching the axis label.
    ax.yaxis.set_major_formatter(PercentFormatter(xmax=1.0))
    ax.set_ylabel("Composition (% of cells)")
    ax.set_xlabel(cluster_key)
    ax.set_title(f"{type_key} composition per {cluster_key}")

    if percent_labels:
        for i, cl in enumerate(props.index):
            cum = 0.0
            for t in types_order:
                h = props.loc[cl, t]
                if h >= label_threshold:
                    ax.text(i, cum + h/2, f"{h*100:.0f}%", ha="center", va="center", fontsize=8, color="white")
                cum += h

    ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", title=type_key)
    fig.tight_layout()

    if savefig:
        fig.savefig(savefig, dpi=200, bbox_inches='tight')
    plt.show()

    # Long-form result for downstream use
    plot_df = (
        props.reset_index()
             .melt(id_vars=cluster_key, var_name="type", value_name="frac")
             .rename(columns={cluster_key: "cluster"})
    )
    plot_df["percent"] = (plot_df["frac"] * 100).round(2)
    counts_long = (
        ct.reset_index()
          .melt(id_vars=cluster_key, var_name="type", value_name="count")
          .rename(columns={cluster_key: "cluster"})
    )
    return plot_df.merge(counts_long, on=["cluster", "type"], how="left")

def _lightness_shades(base_color, n, l_low=0.35, l_high=0.85):
    """
    Return ``n`` hex colors sharing the hue and saturation of ``base_color``.

    HLS lightness is spaced evenly from ``l_low`` (first, darkest) to
    ``l_high`` (last, lightest). The base color's own lightness is discarded,
    and with ``n == 1`` only ``l_low`` is used.
    """
    hue, _, saturation = colorsys.rgb_to_hls(*to_rgb(base_color))
    return [
        to_hex(colorsys.hls_to_rgb(hue, level, saturation))
        for level in np.linspace(l_low, l_high, n)
    ]

def make_hierarchical_palettes(
    df,
    parent_key: str = "celltypes",
    child_key: str  = "cellsubtypes",
    parent_order: Optional[Sequence[str]] = None,
    child_order: str = "alpha",
    base_palette: Optional[Mapping[str, str]] = None,  # legacy param
    parent_palette_dict: Optional[Dict[str, Tuple[float, float, float]]] = None,
    unknown_parent_color: str = "#9e9e9e",
    shade_lightness: Tuple[float, float] = (0.35, 0.85),
) -> Tuple[Dict[str, str], Dict[str, str], List[str], List[str]]:
    """
    Generate parent colors plus child colors that are shades of their parent.

    Parameters
    ----------
    df : pd.DataFrame
        Table with ``parent_key`` and ``child_key`` columns (compared as strings).
    parent_key, child_key : str
        Columns holding the coarse (parent) and fine (child) labels.
    parent_order : sequence, optional
        Preferred parent order; unlisted parents follow. Defaults to frequency order.
    child_order : {"alpha", "freq"}
        Order of children within a parent: alphabetical (default) or by count.
    base_palette : mapping, optional
        Legacy {parent: color}; used only when ``parent_palette_dict`` is None.
    parent_palette_dict : dict, optional
        {parent_label: color (hex or RGB)}; takes precedence over ``base_palette``.
    unknown_parent_color : str
        Color for parents missing from the given palette.
    shade_lightness : (float, float)
        HLS lightness range used for the child shades.

    Returns
    -------
    parent_palette : {parent -> hex}
    child_palette  : {child -> hex}
    parent_order_out : list of parents
    child_order_out  : list of children grouped by parent, each child listed once
    """
    obs = df[[parent_key, child_key]].copy()
    obs[parent_key] = obs[parent_key].astype(str)
    obs[child_key] = obs[child_key].astype(str)

    # --- Parent order ---
    parents = obs[parent_key].unique().tolist()
    if parent_order is not None:
        parent_order_out = [p for p in parent_order if p in parents] + \
                           [p for p in parents if p not in parent_order]
    else:
        parent_order_out = obs[parent_key].value_counts().index.tolist()

    # --- Children per parent ---
    # A child label seen under several parents belongs only to the parent it
    # co-occurs with most often (ties -> first parent in sorted order). This way
    # it gets one color and appears once in child_order_out; a duplicate there
    # would draw the same stack segment twice.
    owner = pd.crosstab(obs[child_key], obs[parent_key]).idxmax(axis=1)
    children_per_parent = {}
    for p in parents:
        sub = obs.loc[obs[parent_key] == p, child_key]
        if child_order == "freq":
            children = sub.value_counts().index.tolist()
        else:
            children = sorted(sub.unique().tolist())
        children_per_parent[p] = [c for c in children if owner[c] == p]

    # --- Parent colors (always hex, as documented) ---
    if parent_palette_dict is not None:
        parent_palette = {
            p: to_hex(parent_palette_dict.get(p, unknown_parent_color))
            for p in parent_order_out
        }
    elif base_palette is not None:
        parent_palette = {
            p: to_hex(base_palette.get(p, unknown_parent_color))
            for p in parent_order_out
        }
    else:
        n_par = len(parent_order_out)
        base = sns.color_palette("tab10" if n_par <= 10 else "hls", n_colors=n_par)
        parent_palette = {p: to_hex(base[i]) for i, p in enumerate(parent_order_out)}

    # --- Child colors as shades of their parent ---
    l_low, l_high = shade_lightness
    child_palette = {}
    child_order_out = []
    for p in parent_order_out:
        kids = children_per_parent.get(p, [])
        if not kids:
            continue
        base_col = parent_palette.get(p, unknown_parent_color)
        shades = _lightness_shades(base_col, len(kids), l_low, l_high)
        child_palette.update(zip(kids, shades))
        child_order_out.extend(kids)

    return parent_palette, child_palette, parent_order_out, child_order_out



In [None]:
# Build coherent palettes & orders: subtype colors are shades of their lineage color.
parent_pal, child_pal, parent_order, child_order = make_hierarchical_palettes(
    adata.obs,
    parent_key="celltypes",
    child_key="cellsubtypes",
    parent_palette_dict=palette_dict_1['celltypes'],
    child_order="alpha",                 # or "freq"
    shade_lightness=(0.35, 0.85)
)

# Numeric niche labels present in the data, in numeric order. Non-numeric
# labels such as "Other" are appended after them because strict_order=False.
niche_order = sorted(
    (lbl for lbl in adata.obs["leiden_joint"].astype(str).unique() if lbl.isdigit()),
    key=int,
)

# (A) 100% stacked bars by lineage (coarse)
plot_df_types = plot_cluster_composition_stacked(
    adata.obs,
    cluster_key="leiden_joint",
    type_key="celltypes",
    cluster_order=niche_order,
    strict_order=False,
    palette=parent_pal,
    percent_labels=True,
    figsize=(10, 4),
    savefig='../../../SpatialFusion/results/figures_Fig5/CRC_major_celltype_stacked_finetuned_barplot.svg',
)

# (B) 100% stacked bars by subtype (fine); segments are grouped by parent
# lineage via type_order and colored as shades of the lineage color.
plot_df_subtypes = plot_cluster_composition_stacked(
    adata.obs,
    cluster_key="leiden_joint",
    type_key="cellsubtypes",
    cluster_order=niche_order,
    strict_order=False,
    palette=child_pal,
    type_order=child_order,
    percent_labels=True,
    figsize=(10, 6),
    savefig='../../../SpatialFusion/results/figures_Fig5/CRC_minor_celltype_stacked_finetuned_barplot.svg',
)


# Characterize niches

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_with_highlights(
    df,
    title,
    highlight_clusters,
    cats,
    base_palette,
    highlight_palette,
    *,
    x_col="X_coord",
    y_col="Y_coord",
    cluster_col="leiden_joint",
    s_base=1.2,
    s_highlight=5,
    alpha_base=0.25,
    alpha_highlight=0.9,
    rasterized=True,
    savefig=None,
):
    """
    Spatial scatterplot of one sample with selected clusters drawn on top.

    Background clusters are drawn first with small, translucent points.
    Highlighted clusters are then drawn over them with larger, more opaque points.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain ``x_col``, ``y_col`` and ``cluster_col``.
    title : str
        Axes title (e.g. the sample ID).
    highlight_clusters : list
        Clusters drawn in the foreground.
    cats : list
        Ordered list of all cluster categories; rows whose cluster is not in
        ``cats`` (and not highlighted) are not drawn.
    base_palette : dict
        Fallback {cluster: color} for clusters missing from ``highlight_palette``.
    highlight_palette : dict
        {cluster: color} used for every cluster it contains, background and
        foreground alike (e.g. gray for background, full colors for highlights).
        Clusters missing from both palettes are drawn 'lightgray'.
    x_col, y_col, cluster_col : str
        Column names for coordinates and cluster labels.
    s_base, s_highlight, alpha_base, alpha_highlight : float
        Marker size / opacity for background and highlighted points.
    rasterized : bool
        Rasterize the point layers (keeps vector outputs small).
    savefig : str, optional
        Output path (dpi=250).
    """
    sns.set_style("white")
    sns.set_context("talk")

    def _color(cluster):
        return highlight_palette.get(cluster, base_palette.get(cluster, "lightgray"))

    fig, ax = plt.subplots(figsize=(6, 5), dpi=300)

    # --- Non-highlighted clusters (background) ---
    non_highlight = [c for c in cats if c not in highlight_clusters]
    df_bg = df[df[cluster_col].isin(non_highlight)]
    if not df_bg.empty:
        sns.scatterplot(
            data=df_bg,
            x=x_col, y=y_col,
            hue=cluster_col,
            hue_order=non_highlight,
            palette={c: _color(c) for c in non_highlight},
            s=s_base,
            alpha=alpha_base,
            linewidth=0,
            legend=False,
            rasterized=rasterized,
            ax=ax,
        )

    # --- Highlighted clusters on top (foreground) ---
    df_fg = df[df[cluster_col].isin(highlight_clusters)]
    if not df_fg.empty:
        sns.scatterplot(
            data=df_fg,
            x=x_col, y=y_col,
            hue=cluster_col,
            hue_order=highlight_clusters,
            palette={c: _color(c) for c in highlight_clusters},
            s=s_highlight,
            alpha=alpha_highlight,
            linewidth=0,
            legend=False,
            rasterized=rasterized,
            ax=ax,
        )

    # --- Bare, image-oriented axes ---
    ax.invert_yaxis()
    ax.set_aspect("equal", adjustable="box")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("")
    ax.set_ylabel("")
    for spine in ["top", "right", "left", "bottom"]:
        ax.spines[spine].set_visible(False)

    ax.set_title(title, pad=6, fontsize=12, fontweight="normal")

    fig.tight_layout()
    if savefig is not None:
        fig.savefig(savefig, bbox_inches='tight', dpi=250)
    plt.show()


In [None]:
from pandas.api.types import CategoricalDtype

# --- Choose clusters to highlight (string labels!)
highlight_clusters = ["1","2"]   # <-- adjust as needed

# --- Prep
emb = all_embeddings.copy()
emb["leiden_joint"] = emb["leiden_joint"].astype(str)

# --- Global, ordered categories across ALL samples: numeric labels first
# (in numeric order), then the remaining labels alphabetically.
all_labels = pd.Index(pd.unique(emb["leiden_joint"]))
def sort_key(lbl):
    return (0, int(lbl)) if lbl.isdigit() else (1, lbl)
cats = sorted(map(str, all_labels), key=sort_key)

cat_type = CategoricalDtype(categories=cats, ordered=True)
emb["leiden_joint"] = emb["leiden_joint"].astype(cat_type)

# --- Highlight palette: original colors for highlighted clusters, gray for the rest
default_gray = "lightgray"
highlight_palette = {
    lbl: (palette_dict_1['leiden_joint'].get(lbl, default_gray) if lbl in highlight_clusters else default_gray)
    for lbl in cats
}

# The output file name follows the chosen clusters (e.g. "1_2").
highlight_tag = "_".join(map(str, highlight_clusters))

# --- One figure per sample
for sample_id, df_sample in emb.groupby("sample_id", sort=True):
    plot_with_highlights(
        df_sample, sample_id, highlight_clusters, cats,
        palette_dict_1['leiden_joint'], highlight_palette,
        s_base=1, s_highlight=2,
        savefig=f'../../../SpatialFusion/results/figures_Fig5/{sample_id}_highlight_{highlight_tag}.svg',
    )


# Pathway activation

In [None]:
def summarize_cluster_pathways(
    adata_obs,
    pathways,
    cluster_col="kmeans_cluster",
    clusters=None,
    global_cluster_order=None
):
    """
    Summarize pathway scores per cluster as a tidy table.

    Parameters
    ----------
    adata_obs : pd.DataFrame
        Table holding ``cluster_col`` and one column per pathway score.
    pathways : str or sequence of str
        Pathway column(s) to summarize.
    cluster_col : str
        Cluster label column (compared as strings).
    clusters : sequence, optional
        Clusters to keep, in this order. Takes precedence over ``global_cluster_order``.
    global_cluster_order : sequence, optional
        Clusters to keep, in this order, when ``clusters`` is None. Defaults to
        all clusters sorted as strings.

    Returns
    -------
    pd.DataFrame
        Columns: Cluster (ordered categorical), Pathway, mean, std, n, sem.
        Missing scores are ignored; ``std``/``sem`` are NaN for single-cell groups.

    Raises
    ------
    ValueError
        If ``cluster_col`` or any pathway column is missing.
    """
    # Accept a single pathway name or any sequence (tuple, Index, ...).
    pathways = [pathways] if isinstance(pathways, str) else list(pathways)

    if cluster_col not in adata_obs.columns:
        raise ValueError(f"'{cluster_col}' not found in adata.obs")

    missing = [p for p in pathways if p not in adata_obs.columns]
    if missing:
        raise ValueError(f"These pathways are missing in adata.obs: {missing}")

    df = adata_obs[[cluster_col, *pathways]].copy()
    df[cluster_col] = df[cluster_col].astype(str)

    # Clusters to report, in display order.
    if clusters is not None:
        selected = [str(c) for c in clusters]
    elif global_cluster_order is not None:
        selected = [str(c) for c in global_cluster_order]
    else:
        selected = sorted(df[cluster_col].unique())

    df = df[df[cluster_col].isin(selected)].copy()

    # melt -> summarize -> tidy
    long = df.melt(id_vars=[cluster_col], value_vars=pathways,
                   var_name="Pathway", value_name="Score").dropna(subset=["Score"])

    summary = (
        long.groupby([cluster_col, "Pathway"], as_index=False)
            .agg(mean=("Score", "mean"),
                 std =("Score", "std"),
                 n   =("Score", "size"))
    )
    summary["sem"] = summary["std"] / np.sqrt(summary["n"].clip(lower=1))
    summary.rename(columns={cluster_col: "Cluster"}, inplace=True)

    # Consistent cluster order
    summary["Cluster"] = pd.Categorical(summary["Cluster"], categories=selected, ordered=True)
    return summary.sort_values(["Cluster", "Pathway"]).reset_index(drop=True)

In [None]:
def plot_pathway_bars_by_cluster(
    adata_obs,
    pathways,
    cluster_col="kmeans_cluster",
    clusters=None,
    global_cluster_order=None,
    err="sem",        # "sem" or "ci"
    ci_level=95,
    figsize=None,
    ylim=None,
    palette=None,     # user-specified palette dict or name
    savefig=None,
):
    """
    Grouped bar chart of mean pathway scores: pathways on x, one bar per cluster.

    Error bars are drawn by hand from the per-cluster summary. They show the SEM,
    or a normal-approximation ``ci_level``% CI (z * SEM) when ``err == "ci"``.

    Parameters
    ----------
    adata_obs : pd.DataFrame
        Table holding ``cluster_col`` and every pathway column.
    pathways : list of str
        Pathway columns; their order is the x-axis order.
    cluster_col, clusters, global_cluster_order
        Forwarded to ``summarize_cluster_pathways``.
    err : {"sem", "ci"}
        Error-bar quantity; anything other than "ci" means SEM.
    ci_level : float
        Confidence level in percent for ``err == "ci"``.
    figsize : tuple, optional
        Defaults to a width that grows with the number of pathways.
    ylim : tuple, optional
        y-axis limits.
    palette : dict, list, or str, optional
        A dict {cluster: color} is used as-is; a seaborn palette name or list is
        expanded to one color per cluster; None means tab20.
    savefig : str, optional
        Output path (dpi=200).

    Returns
    -------
    pd.DataFrame
        The summary table with a ``yerr`` column, sorted by pathway then cluster.
    """
    from scipy.stats import norm

    summary = summarize_cluster_pathways(
        adata_obs,
        pathways,
        cluster_col=cluster_col,
        clusters=clusters,
        global_cluster_order=global_cluster_order,
    )

    # x-axis follows the order given in `pathways`
    summary["Pathway"] = pd.Categorical(summary["Pathway"], categories=pathways, ordered=True)
    summary = summary.sort_values(["Pathway", "Cluster"]).reset_index(drop=True)

    if err == "ci":
        summary["yerr"] = norm.ppf(0.5 + ci_level / 200.0) * summary["sem"]
        err_label = f"{ci_level}% CI"
    else:
        summary["yerr"] = summary["sem"]
        err_label = "SEM"

    cluster_levels = summary["Cluster"].cat.categories.tolist()
    n_clusters = len(cluster_levels)

    if isinstance(palette, dict):
        cluster_palette = palette
    elif palette is not None:
        # seaborn palette name or list of colors
        cluster_palette = sns.color_palette(palette, n_colors=n_clusters)
    else:
        tab = sns.color_palette("tab20", n_colors=max(10, n_clusters))
        cluster_palette = {lvl: tab[k % len(tab)] for k, lvl in enumerate(cluster_levels)}

    if figsize is None:
        figsize = (max(8, 0.8 * len(pathways)), 5 + 0.15 * n_clusters)

    plt.figure(figsize=figsize)
    ax = sns.barplot(
        data=summary,
        x="Pathway", y="mean",
        hue="Cluster",
        palette=cluster_palette,
        dodge=True,
        errorbar=None
    )

    # Manual error bars: seaborn dodges the hue levels across a 0.8-wide slot
    # centred on each x position.
    x_labels = list(summary["Pathway"].cat.categories)
    slot = 0.8 / max(1, n_clusters)
    for j, lvl in enumerate(cluster_levels):
        rows = (
            summary[summary["Cluster"] == lvl]
            .set_index("Pathway")
            .reindex(x_labels)
        )
        centres = [x_idx - 0.4 + (j + 0.5) * slot for x_idx in range(len(x_labels))]
        ax.errorbar(centres, rows["mean"].to_numpy(), yerr=rows["yerr"].to_numpy(), fmt="none",
                    ecolor="black", elinewidth=1, capsize=3, capthick=1)

    ax.set_title(f"Mean pathway score per pathway (± {err_label})")
    ax.set_ylabel("Mean pathway score")
    ax.set_xlabel("Pathway")
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.legend(title="Cluster", bbox_to_anchor=(1.02, 1), loc="upper left")
    plt.tight_layout()
    if savefig is not None:
        plt.savefig(savefig, dpi=200, bbox_inches='tight')
    plt.show()

    return summary


In [None]:
# Per-sample pathway activation scores, re-indexed as '<cell>::<sample>' so they
# match adata.obs_names.
path_frames = []
for sample in sample_list:
    frame = pd.read_parquet(base_dir / sample / 'pathway_activation.parquet')
    frame.index = frame.index.astype(str) + '::' + sample
    path_frames.append(frame)
path_mat = pd.concat(path_frames)

# Drop score columns that are already present so that re-running this cell
# does not duplicate them.
obs_base = adata.obs.drop(columns=path_mat.columns.intersection(adata.obs.columns))
adata.obs = pd.concat([obs_base, path_mat.loc[adata.obs_names]], axis=1)

In [None]:
pathways = [
    'Androgen', 'EGFR', 'Estrogen', 'JAK-STAT', 'MAPK',
    'NFkB', 'PI3K', 'TGFb', 'TNFa', 'VEGF',
]

# Panel 1: niches '0'..'10', 95% normal-approximation CIs.
plot_pathway_bars_by_cluster(
    adata.obs,
    pathways=pathways,
    cluster_col="leiden_joint",
    global_cluster_order=[str(i) for i in range(11)],
    palette=palette_dict_1['leiden_joint'],
    err="ci",
    ci_level=95,
    figsize=(15, 5),
    savefig='../../../SpatialFusion/results/figures_Fig5/pathway_score_1_barplot.svg',
)


In [None]:
pathways = [
    'Androgen', 'EGFR', 'Estrogen', 'JAK-STAT', 'MAPK',
    'NFkB', 'PI3K', 'TGFb', 'TNFa', 'VEGF',
]

# Panel 2: niches '11'..'17', 95% normal-approximation CIs.
plot_pathway_bars_by_cluster(
    adata.obs,
    pathways=pathways,
    cluster_col="leiden_joint",
    global_cluster_order=[str(i) for i in range(11, 18)],
    palette=palette_dict_1['leiden_joint'],
    err="ci",
    ci_level=95,
    figsize=(15, 5),
    savefig='../../../SpatialFusion/results/figures_Fig5/pathway_score_2_barplot.svg',
)


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import mannwhitneyu
import pandas as pd
import numpy as np

def plot_pathway_with_significance(
    adata_obs,
    pathway="Androgen",
    sample_col="sample_id",
    cluster_col="leiden_joint",
    hue_order=None,
    order=None,
    palette=None,
    figsize=(10, 4),
    ymax=4,  # y position of the significance bracket and label
    savefig=None,
):
    """
    Plot per-sample mean pathway scores for two clusters and annotate MWU significance.

    For every sample the mean of ``pathway`` is drawn as one bar per cluster in
    ``hue_order``. The cells of the first two clusters are then compared with a
    two-sided Mann-Whitney U test, and the result is written as a star label
    above a bracket at height ``ymax``.

    Parameters
    ----------
    adata_obs : pandas.DataFrame
        Cell-level table (e.g. ``adata.obs``) containing ``sample_col``,
        ``cluster_col`` and ``pathway``.
    pathway : str
        Column holding the score to plot and test.
    sample_col : str
        Column with sample IDs; one bar group per sample.
    cluster_col : str
        Column with cluster labels; values are compared as strings.
    hue_order : sequence, optional
        Cluster labels to plot; the first two are tested against each other.
        Defaults to ``["1", "2"]``. Labels are cast to ``str`` to match the
        string-cast cluster column.
    order : sequence of str, optional
        Sample order on the x axis; defaults to the sorted unique sample IDs.
    palette : optional
        Seaborn palette for the clusters.
    figsize : tuple
        Figure size in inches.
    ymax : float
        Y coordinate of the significance bracket and star label.
    savefig : str or path, optional
        If given, the figure is saved there at dpi=200.

    Returns
    -------
    pandas.DataFrame
        One row per sample with columns ``sample_id`` and ``p_value``.
        ``p_value`` is NaN when either cluster has no non-missing scores in that sample.

    Raises
    ------
    ValueError
        If ``hue_order`` contains fewer than two labels.
    """
    # None instead of a list literal avoids a shared mutable default argument.
    hue_order = ["1", "2"] if hue_order is None else [str(h) for h in hue_order]
    if len(hue_order) < 2:
        raise ValueError("hue_order must contain at least two cluster labels to compare.")

    # Prepare dataframe; sample and cluster labels are compared as strings throughout.
    df = adata_obs[[sample_col, cluster_col, pathway]].copy()
    df[sample_col] = df[sample_col].astype(str)
    df[cluster_col] = df[cluster_col].astype(str)

    # Default order = sorted unique samples
    if order is None:
        order = sorted(df[sample_col].unique())

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    sns.barplot(
        data=df,
        x=sample_col,
        y=pathway,
        hue=cluster_col,
        hue_order=hue_order,
        order=order,
        palette=palette,
        dodge=True,
        estimator=np.mean,
        errorbar=None,
        ax=ax
    )

    # Two-sided MWU per sample, in x-axis order so results line up with the bar groups.
    results = []
    for sample in order:
        sub = df[df[sample_col] == sample]
        # Drop missing scores; NaNs would otherwise propagate into the test.
        g1 = sub.loc[sub[cluster_col] == hue_order[0], pathway].dropna()
        g2 = sub.loc[sub[cluster_col] == hue_order[1], pathway].dropna()
        if len(g1) > 0 and len(g2) > 0:
            p = mannwhitneyu(g1, g2, alternative="two-sided").pvalue
        else:
            p = np.nan
        results.append((sample, p))

    def p_to_star(p):
        """Map a p-value to the usual star notation (ns / * / ** / ***)."""
        if p < 0.001:
            return "***"
        elif p < 0.01:
            return "**"
        elif p < 0.05:
            return "*"
        else:
            return "ns"

    # Seaborn splits the default 0.8 group width evenly between the hue levels.
    bar_width = 0.8 / len(hue_order)
    for x_center, (sample, p) in enumerate(results):
        if np.isnan(p):
            continue
        # The bracket spans from the centre of the first bar to the centre of the last bar.
        x1 = x_center - 0.4 + 0.5 * bar_width
        x2 = x_center + 0.4 - 0.5 * bar_width
        ax.plot([x1, x2], [ymax, ymax], color='black', lw=1)
        ax.text(x_center, ymax, p_to_star(p), ha='center', va='bottom', fontsize=10)

    ax.set_title(f"{pathway} pathway scores by patient (MWU significance)")
    ax.set_xlabel("Patient")
    ax.set_ylabel("Pathway score")
    ax.tick_params(axis="x", labelrotation=90)
    ax.legend(title="Cluster", bbox_to_anchor=(1.02, 1), loc="upper left")
    plt.tight_layout()
    if savefig is not None:
        fig.savefig(savefig, dpi=200, bbox_inches='tight')
        print(f'Saved figure at {savefig}.')
    plt.show()

    return pd.DataFrame(results, columns=["sample_id", "p_value"])


In [None]:
# Per-sample Androgen scores for leiden_joint clusters 1 vs 2 with a two-sided MWU test per sample.
# The returned p-value DataFrame is the cell's last expression, so Jupyter displays it.
plot_pathway_with_significance(
    adata.obs,
    pathway="Androgen",
    sample_col="sample_id",
    cluster_col="leiden_joint",
    order=['P1CRC','P2CRC','P5CRC','P3NAT','P5NAT'],
    hue_order=["1", "2"],
    palette=palette_dict_1['leiden_joint'],
    figsize=(5, 2),
    ymax=4,
    savefig = '../../../SpatialFusion/results/figures_Fig5/androgen_perpatient_niche12.svg',
)

In [None]:
# Per-sample JAK-STAT scores for leiden_joint clusters 1 vs 2 with a two-sided MWU test per sample.
# The returned p-value DataFrame is the cell's last expression, so Jupyter displays it.
plot_pathway_with_significance(
    adata.obs,
    pathway="JAK-STAT",
    sample_col="sample_id",
    cluster_col="leiden_joint",
    order=['P1CRC','P2CRC','P5CRC','P3NAT','P5NAT'],
    hue_order=["1", "2"],
    palette=palette_dict_1['leiden_joint'],
    figsize=(5, 2),
    ymax=4,
    savefig = '../../../SpatialFusion/results/figures_Fig5/JAK-STAT_perpatient_niche12.svg',
)

# Transform adata

In [None]:
# Keep a copy of the current X in layers['counts'] (presumably raw counts at this point — confirm),
# then scale each cell to 10,000 total counts and log1p-transform X in place.
# NOTE(review): not idempotent — re-running this cell overwrites layers['counts'] with already
# normalized data and log-transforms X a second time. Restart from the raw object before re-running.
adata.layers['counts'] = adata.X.copy()

sc.pp.normalize_total(adata, target_sum=10000)
sc.pp.log1p(adata)

In [None]:
def plot_annotation(
    ax,
    adata,
    column,
    title,
    palette,
    vmin=-5,
    vmax=5,
    x_key="X_coord",
    y_key="Y_coord",
    point_size=3,
    xlim=None,
    ylim=None,
    colorbar_info=None,
    legends_info=None,
    rasterize_points=True,
):
    """
    Plot an annotation layer (continuous or categorical) in Nature Genetics style.

    Numeric columns are clipped to ``[vmin, vmax]`` and drawn as a continuous
    scatter. All other columns are cast to ``str`` and drawn as a categorical
    scatter without an in-axis legend.

    Parameters
    ----------
    ax : matplotlib Axes
        Axis to draw on.
    adata : AnnData
        AnnData object containing .obs[column] and coordinates.
    column : str
        Column name in adata.obs to visualize.
    title : str
        Panel title to display.
    palette : str, list, dict, or colormap
        Palette for categorical data or cmap for continuous data.
    vmin, vmax : float
        Limits for continuous color scaling (clipping).
    x_key, y_key : str
        Keys for spatial coordinates in adata.obs.
    point_size : float
        Scatter point size.
    xlim, ylim : tuple or None
        Manual axis limits if needed. The y axis is inverted afterwards, so
        the smaller y value ends up at the top.
    colorbar_info : list or None
        If given, receives a ``(scatter, title)`` tuple for continuous columns
        so a shared colorbar can be drawn later.
    legends_info : list or None
        If given, receives a ``(title, handles, labels)`` tuple for
        categorical columns so a separate legend can be drawn later.
    rasterize_points : bool
        Rasterize scatter for smaller vector file size.

    Raises
    ------
    KeyError
        If ``column`` is not in ``adata.obs``.
    ValueError
        If ``column`` matches several columns of ``adata.obs``.
    """
    if column not in adata.obs.columns:
        raise KeyError(f"Column '{column}' not found in adata.obs")

    values = adata.obs[column]
    # Duplicate column names make obs[column] return a DataFrame.
    if isinstance(values, pd.DataFrame):
        raise ValueError(f"Column '{column}' is not unique in adata.obs (multiple matches)")

    sns.set_style("white")

    if pd.api.types.is_numeric_dtype(values):
        # === Continuous variable ===
        clipped_values = values.clip(lower=vmin, upper=vmax)

        # Named `scatter` (not `sc`) to avoid shadowing the scanpy module alias.
        scatter = ax.scatter(
            adata.obs[x_key],
            adata.obs[y_key],
            c=clipped_values.values,
            cmap=palette,
            s=point_size,
            alpha=0.8,
            linewidth=0,
            vmin=vmin,
            vmax=vmax,
            rasterized=rasterize_points,
        )

        if colorbar_info is not None:
            colorbar_info.append((scatter, title))

    else:
        # === Categorical variable ===
        hue_vals = values.astype(str)

        sns.scatterplot(
            x=adata.obs[x_key],
            y=adata.obs[y_key],
            hue=hue_vals,
            palette=palette,
            s=point_size,
            linewidth=0,
            alpha=0.9,
            ax=ax,
            legend=False,
            rasterized=rasterize_points,
        )

        if legends_info is not None:
            # With legend=False seaborn adds no labelled artists to `ax`, so
            # ax.get_legend_handles_labels() would come back empty. Rebuild one
            # marker per hue level from the same palette instead, in seaborn's
            # level order for a string hue (order of first appearance).
            levels = list(pd.unique(hue_vals))
            if isinstance(palette, dict):
                colors = [palette[level] for level in levels]
            elif palette is None and len(levels) > len(sns.color_palette()):
                # seaborn falls back to husl when the default color cycle is too short.
                colors = sns.color_palette("husl", len(levels))
            else:
                colors = sns.color_palette(palette, len(levels))
            handles = [
                plt.Line2D([], [], marker="o", linestyle="", color=color, markeredgewidth=0)
                for color in colors
            ]
            legends_info.append((title, handles, levels))

    # --- general aesthetics ---
    ax.set_title(title, pad=10)
    ax.set_aspect("equal")

    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    # Image-style orientation: smaller y at the top.
    ax.invert_yaxis()

    # no axes, ticks, or spines
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ["top", "right", "left", "bottom"]:
        ax.spines[spine].set_visible(False)
    ax.grid(False)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_spatial_with_HE(
    wsi,
    adata,
    sample_name,
    palette,
    xlim,
    ylim,
    x_col="X_coord",
    y_col="Y_coord",
    cluster_col="leiden_joint",
    point_size=7,
    he_alpha=0.9,
    scatter_alpha=0.85,
    figsize=(7, 7),
    dpi=300,
    show=True,
    savefig=None,
):
    """
    Plot a cropped H&E region with a spatial cluster scatter overlay.

    Minimal 'Nature Genetics'-style panel: no ticks, labels or spines.

    Parameters
    ----------
    wsi : np.ndarray
        Full-resolution H&E image indexed as [row=y, col=x(, channel)].
    adata : AnnData
        Annotated data object with spatial coordinates in ``obs``. All cells
        are scattered; the axis limits crop them to the window.
    sample_name : str
        Not used by this function; kept for call compatibility.
    palette : dict
        Palette, e.g. palette_dict['leiden_joint'].
    xlim, ylim : tuple of int
        Crop window ``(start, stop)`` in pixel coordinates (H&E space).
    x_col, y_col : str
        Column names for spatial coordinates.
    cluster_col : str
        Column name for cluster annotation in adata.obs.
    point_size : int or float
        Marker size for scatter points.
    he_alpha : float
        Transparency of the H&E image (0–1).
    scatter_alpha : float
        Transparency of scatter overlay (0–1).
    figsize : tuple
        Figure size in inches.
    dpi : int
        Figure and output DPI.
    show : bool
        Whether to display the plot immediately.
    savefig : str or path, optional
        If given, the figure is saved there at ``dpi``.

    Returns
    -------
    (matplotlib.figure.Figure, matplotlib.axes.Axes)
    """
    # --- Crop ROI ---
    x0, x1 = xlim
    y0, y1 = ylim
    # Rows index y and columns index x; slicing only the first two axes works for
    # both RGB (H, W, C) and single-channel (H, W) images.
    roi = wsi[y0:y1, x0:x1]

    sns.set_style("white")
    sns.set_context("talk")
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # extent=(left, right, bottom, top) with origin="upper" puts image row y0 at the top.
    ax.imshow(
        roi,
        origin="upper",
        extent=(x0, x1, y1, y0),
        alpha=he_alpha
    )

    sns.scatterplot(
        data=adata.obs,
        x=x_col, y=y_col,
        hue=cluster_col,
        palette=palette,
        s=point_size,
        alpha=scatter_alpha,
        linewidth=0,
        rasterized=True,
        ax=ax,
        legend=False
    )

    # --- Crop and formatting ---
    ax.set_xlim(x0, x1)
    # Keep y0 at the top (image convention), the same orientation the spatial
    # annotation panels use. Do not additionally call invert_yaxis(): that would
    # flip both the H&E crop and the points upside down.
    ax.set_ylim(y1, y0)
    ax.set_aspect("equal", adjustable="box")

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("")
    ax.set_ylabel("")
    for spine in ["top", "right", "left", "bottom"]:
        ax.spines[spine].set_visible(False)

    plt.tight_layout()
    if savefig is not None:
        fig.savefig(savefig, bbox_inches='tight', dpi=dpi)
    if show:
        plt.show()

    return fig, ax


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import matplotlib.tri as mtri
from matplotlib.tri import TriAnalyzer


# --- helper: build triangulation ---
def _build_triangulation(df, x_col='X_coord', y_col='Y_coord', min_circle_ratio=0.01):
    x = df[x_col].values
    y = df[y_col].values
    if x.size < 3:
        return None
    tri = mtri.Triangulation(x, y)
    try:
        mask = TriAnalyzer(tri).get_flat_tri_mask(min_circle_ratio=min_circle_ratio)
        tri.set_mask(mask)
    except Exception:
        pass
    return tri


# --- helper: draw non-convex outlines via tricontour ---
def _overlay_niches_tricontour(
    ax,
    df,
    tri,
    niche_col='leiden_joint',
    palette=None,
    highlight_niches=None,
    eps=0.02,
    line_width=1.8,
    halo_width=3.0,
    outline_alpha=1.0,
    zorder_base=10,
):
    """
    Draw one or more disconnected non-convex outlines for each niche
    using tricontour over a binary mask of niche labels.
    """

    if tri is None:
        return

    # Get cluster labels for triangulation points
    labels = df[niche_col].astype(str).values

    # Which niches to draw (all or subset)
    if highlight_niches is None:
        lbls = sorted(np.unique(labels))
    else:
        lbls = [str(l) for l in highlight_niches]

    # Define color per label
    color_by_label = {}
    for lbl in lbls:
        if palette is not None and lbl in palette:
            color_by_label[lbl] = palette[lbl]
        else:
            # fallback color
            color_by_label[lbl] = "gray"

    # Small ±eps shift avoids numerical overlap artifacts
    level_for = {lbl: 0.5 + (eps if (i % 2) else -eps)
                 for i, lbl in enumerate(lbls)}

    # Draw contours
    for lbl in lbls:
        z = (labels == lbl).astype(float)

        # white halo first
        if halo_width and halo_width > 0:
            ax.tricontour(
                tri, z,
                levels=[level_for[lbl]],
                colors=['white'],
                linewidths=halo_width,
                alpha=outline_alpha,
                zorder=zorder_base
            )

        # then the colored outline
        ax.tricontour(
            tri, z,
            levels=[level_for[lbl]],
            colors=[color_by_label[lbl]],
            linewidths=line_width,
            alpha=outline_alpha,
            zorder=zorder_base + 1
        )


# --- main plot function ---
def plot_spatial_with_HE_outline(
    wsi,
    adata,
    sample_name,
    palette,
    xlim,
    ylim,
    x_col="X_coord",
    y_col="Y_coord",
    cluster_col="leiden_joint",
    he_alpha=0.9,
    outline_alpha=1.0,
    outline_width=2,
    halo_width=3.0,
    figsize=(7, 7),
    dpi=300,
    min_circle_ratio=0.01,
    highlight_niches=None,  # e.g. ['0','2','5']
    savefig=None,
    show=True,
):
    """
    Plot an H&E crop with non-convex niche outlines (tricontour) in palette colors.

    Parameters
    ----------
    wsi : np.ndarray
        Full-resolution H&E image indexed as [row=y, col=x(, channel)].
    adata : AnnData
        Cells with coordinates in ``obs[x_col]`` / ``obs[y_col]``.
    sample_name : str
        Not used by this function; kept for call compatibility.
    palette : dict
        Niche label -> colour for the outlines.
    xlim, ylim : tuple of int
        Crop window ``(start, stop)`` in H&E pixel coordinates. Only cells
        inside it (inclusive) are triangulated.
    x_col, y_col, cluster_col : str
        Coordinate and niche-label columns of ``adata.obs``.
    he_alpha, outline_alpha : float
        Alpha of the H&E image and of the outlines.
    outline_width, halo_width : float
        Outline and white-halo line widths.
    figsize, dpi
        Figure size (inches) and DPI (also used when saving).
    min_circle_ratio : float
        Flat-triangle masking threshold for the triangulation.
    highlight_niches : list or None
        Niche labels to outline; ``None`` outlines all labels in the crop.
    savefig : str or path, optional
        If given, the figure is saved there.
    show : bool
        Whether to call ``plt.show()``.

    Returns
    -------
    (matplotlib.figure.Figure, matplotlib.axes.Axes)

    Raises
    ------
    ValueError
        If no cells fall inside the crop window.
    """
    # --- Crop ROI ---
    x0, x1 = xlim
    y0, y1 = ylim
    # Slice only the first two axes so RGB (H, W, C) and grayscale (H, W) both work.
    roi = wsi[y0:y1, x0:x1]

    # --- Subset cells to the ROI ---
    df = adata.obs.copy()
    df = df[(df[x_col] >= x0) & (df[x_col] <= x1) &
            (df[y_col] >= y0) & (df[y_col] <= y1)]

    if df.empty:
        raise ValueError("No cells found in ROI")

    tri = _build_triangulation(df, x_col=x_col, y_col=y_col, min_circle_ratio=min_circle_ratio)

    sns.set_style("white")
    sns.set_context("talk")
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # extent=(left, right, bottom, top) with origin="upper" puts image row y0 at the top.
    ax.imshow(
        roi,
        origin="upper",
        extent=(x0, x1, y1, y0),
        alpha=he_alpha
    )

    _overlay_niches_tricontour(
        ax, df, tri,
        niche_col=cluster_col,
        palette=palette,
        highlight_niches=highlight_niches,
        eps=0.02,
        line_width=outline_width,
        halo_width=halo_width,
        outline_alpha=outline_alpha,
        zorder_base=10,
    )

    # --- Formatting ---
    ax.set_xlim(x0, x1)
    # Keep y0 at the top (image convention), matching the pathway/annotation panels of
    # the same window. Do not additionally call invert_yaxis(): that would flip the
    # H&E crop and the outlines upside down.
    ax.set_ylim(y1, y0)
    ax.set_aspect("equal", adjustable="box")

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("")
    ax.set_ylabel("")
    for spine in ["top", "right", "left", "bottom"]:
        ax.spines[spine].set_visible(False)

    plt.tight_layout()
    if savefig is not None:
        fig.savefig(savefig, bbox_inches='tight', dpi=dpi)
    if show:
        plt.show()

    return fig, ax


In [None]:
def plot_annotation_wo_HE(
    adata,
    ax,
    column,
    title,
    palette,
    marker_size=None,
    x_range=None,
    y_range=None,
    legend_store=None,
):
    """
    Draw a single Nature Genetics–style categorical scatter panel (no H&E).

    The legend entries for ``column`` are appended to a list as
    ``(title, handles, labels)`` so that a shared legend can be drawn
    elsewhere. The panel itself is drawn without a legend and with rasterized points.

    Parameters
    ----------
    adata : AnnData
        Must provide ``obs[column]``, ``obs["X_coord"]`` and ``obs["Y_coord"]``.
    ax : matplotlib Axes
        Axis to draw the panel on.
    column : str
        ``adata.obs`` column used as hue (cast to ``str``).
    title : str
        Panel title; also stored with the legend entries.
    palette : str, list or dict
        Seaborn palette for the hue levels.
    marker_size : float, optional
        Point size. Defaults to the module-level global ``point_size``.
    x_range, y_range : tuple, optional
        Axis limits. Default to the module-level globals ``xlim`` / ``ylim``.
        The y axis is inverted after the limits are applied.
    legend_store : list, optional
        List that receives the legend tuple. Defaults to the module-level
        global ``legends_info``.
    """
    # Fall back to the notebook globals only when no explicit value is passed,
    # so existing 5-argument calls keep their behavior.
    size = point_size if marker_size is None else marker_size
    x_limits = xlim if x_range is None else x_range
    y_limits = ylim if y_range is None else y_range
    store = legends_info if legend_store is None else legend_store

    hue_vals = adata.obs[column].astype(str)

    # Plot once on a throwaway figure with legend=True only to harvest the legend
    # handles; the figure is closed even if plotting fails, so it cannot leak.
    tmp_fig = plt.figure()
    try:
        tmp_ax = tmp_fig.add_subplot(111)
        sns.scatterplot(
            x=adata.obs["X_coord"],
            y=adata.obs["Y_coord"],
            hue=hue_vals,
            palette=palette,
            s=size,
            ax=tmp_ax,
            linewidth=0,
            alpha=0.9,
            legend=True,
        )
        handles, labels = tmp_ax.get_legend_handles_labels()
        legend = tmp_ax.get_legend()
        if not handles and legend is not None:
            # NOTE(review): some seaborn/matplotlib versions may not register legend
            # artists on the axes; read them from the Legend object instead.
            if hasattr(legend, "legend_handles"):
                handles = list(legend.legend_handles)
            else:
                handles = list(legend.legendHandles)
            labels = [text.get_text() for text in legend.get_texts()]
    finally:
        plt.close(tmp_fig)

    # Deduplicate and drop empty / placeholder labels, keeping first occurrence.
    by_label = {}
    for h, l in zip(handles, labels):
        if l and l != "_nolegend_" and l not in by_label:
            by_label[l] = h
    store.append((title, list(by_label.values()), list(by_label.keys())))

    # Actual subplot (no legend, rasterized)
    sns.scatterplot(
        x=adata.obs["X_coord"],
        y=adata.obs["Y_coord"],
        hue=hue_vals,
        palette=palette,
        s=size,
        ax=ax,
        linewidth=0,
        alpha=0.9,
        legend=False,
        rasterized=True,
    )

    ax.set_xlim(x_limits)
    ax.set_ylim(y_limits)
    # Image-style orientation: smaller y at the top.
    ax.invert_yaxis()
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_aspect("equal")

    ax.set_title(title, pad=6, fontweight="normal")
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ["top", "right", "left", "bottom"]:
        ax.spines[spine].set_visible(False)
    ax.grid(False)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.tri as mtri
from matplotlib.tri import TriAnalyzer


# ---------- triangulation + outline helpers ----------

def _build_triangulation(df, x_col='X_coord', y_col='Y_coord', min_circle_ratio=0.01):
    x = df[x_col].values
    y = df[y_col].values
    if x.size < 3:
        return None
    tri = mtri.Triangulation(x, y)
    try:
        mask = TriAnalyzer(tri).get_flat_tri_mask(min_circle_ratio=min_circle_ratio)
        tri.set_mask(mask)
    except Exception:
        pass
    return tri


def _overlay_niches_tricontour(
    ax,
    df,
    tri,
    niche_col='leiden_joint',
    palette=None,
    highlight_niches=None,
    eps=0.02,
    line_width=1.8,
    halo_width=3.0,
    outline_alpha=1.0,
    zorder_base=10,
):
    """Draw disconnected, non-convex outlines for each niche."""
    if tri is None:
        return

    labels = df[niche_col].astype(str).values

    if highlight_niches is None:
        lbls = sorted(np.unique(labels))
    else:
        lbls = [str(l) for l in highlight_niches]

    color_by_label = {}
    for lbl in lbls:
        if palette is not None and lbl in palette:
            color_by_label[lbl] = palette[lbl]
        else:
            color_by_label[lbl] = "gray"

    level_for = {lbl: 0.5 + (eps if (i % 2) else -eps)
                 for i, lbl in enumerate(lbls)}

    for lbl in lbls:
        z = (labels == lbl).astype(float)
        if halo_width and halo_width > 0:
            ax.tricontour(
                tri, z, levels=[level_for[lbl]],
                colors=['white'], linewidths=halo_width,
                alpha=outline_alpha, zorder=zorder_base)
        ax.tricontour(
            tri, z, levels=[level_for[lbl]],
            colors=[color_by_label[lbl]], linewidths=line_width,
            alpha=outline_alpha, zorder=zorder_base + 1)


# ---------- main plotting function with outline ----------

def plot_annotation_pathway(
    adata,
    ax,
    column,
    title,
    palette,
    vmin=-5,
    vmax=5,
    point_size=2,
    xlim=None,
    ylim=None,
    colorbar_info=None,
    legends_info=None,
    # --- new outline options ---
    cluster_col=None,
    palette_niches=None,
    highlight_niches=None,
    outline_width=2.0,
    halo_width=3.0,
    outline_alpha=1.0,
    min_circle_ratio=0.01,
):
    """
    Nature Genetics–style spatial scatter of one obs column, with optional niche outlines.

    Numeric columns are clipped to ``[vmin, vmax]`` and coloured with
    ``palette`` as colormap. Other columns are drawn as a categorical hue.
    When ``cluster_col`` is an obs column and ``palette_niches`` is given,
    niche boundaries are traced with triangulation-based contours.

    Parameters
    ----------
    adata : AnnData
        Must provide ``obs[column]``, ``obs["X_coord"]`` and ``obs["Y_coord"]``.
    ax : matplotlib Axes
        Axis to draw on.
    column : str
        ``adata.obs`` column to visualize.
    title : str
        Panel title.
    palette : str, list, dict or colormap
        Colormap for numeric columns, hue palette otherwise.
    vmin, vmax : float
        Clipping / colour limits for numeric columns.
    point_size : float
        Scatter marker size.
    xlim, ylim : tuple or None
        Axis limits; the y axis is inverted afterwards (smaller y at the top).
        When both are given, outlines are computed only from cells inside the window.
    colorbar_info : list or None
        Receives ``(scatter, title)`` for numeric columns.
    legends_info : list or None
        Receives ``(title, handles, labels)`` for categorical columns.
    cluster_col : str or None
        Niche-label column used for the outlines.
    palette_niches : dict or None
        Niche label -> outline colour; outlines are drawn only when given.
    highlight_niches : list or None
        Niche labels to outline; ``None`` outlines all labels in view.
    outline_width, halo_width, outline_alpha : float
        Outline styling.
    min_circle_ratio : float
        Flat-triangle masking threshold for the triangulation.

    Raises
    ------
    KeyError
        If ``column`` is not in ``adata.obs``.
    ValueError
        If ``column`` matches several columns of ``adata.obs``.
    """
    sns.set_style("white")
    sns.set_context("talk")

    if column not in adata.obs.columns:
        raise KeyError(f"Column '{column}' not found in adata.obs")

    df = adata.obs
    values = df[column]
    # Duplicate column names make df[column] return a DataFrame.
    if isinstance(values, pd.DataFrame):
        raise ValueError(f"Column '{column}' has multiple matches in adata.obs")

    # --- restrict outline computation to the visible window ---
    if xlim is not None and ylim is not None:
        x0, x1 = xlim
        y0, y1 = ylim
        view_mask = df["X_coord"].between(x0, x1) & df["Y_coord"].between(y0, y1)
        df_view = df.loc[view_mask].copy()
    else:
        df_view = df

    if pd.api.types.is_numeric_dtype(values):
        # --- continuous variable ---
        clipped_values = values.clip(lower=vmin, upper=vmax)
        # Named `scatter` (not `sc`) to avoid shadowing the scanpy module alias.
        scatter = ax.scatter(
            df["X_coord"], df["Y_coord"],
            c=clipped_values.values,
            cmap=palette,
            s=point_size,
            alpha=0.8,
            linewidth=0,
            vmin=vmin, vmax=vmax,
            rasterized=True,
        )
        if colorbar_info is not None:
            colorbar_info.append((scatter, title))

    else:
        # --- categorical variable ---
        sns.scatterplot(
            data=df,
            x="X_coord", y="Y_coord",
            hue=column, palette=palette,
            s=point_size, ax=ax,
            linewidth=0, alpha=0.9,
            legend=False, rasterized=True,
        )
        if legends_info is not None:
            # With legend=False seaborn adds no labelled artists to `ax`, so
            # ax.get_legend_handles_labels() would come back empty. Rebuild one
            # marker per hue level from the palette, using seaborn's level order
            # (categories for categorical dtype, otherwise order of appearance).
            if isinstance(values.dtype, pd.CategoricalDtype):
                levels = [lvl for lvl in values.cat.categories if pd.notnull(lvl)]
            else:
                levels = [lvl for lvl in pd.unique(values) if pd.notnull(lvl)]
            if isinstance(palette, dict):
                colors = [palette[lvl] for lvl in levels]
            elif palette is None and len(levels) > len(sns.color_palette()):
                # seaborn falls back to husl when the default color cycle is too short.
                colors = sns.color_palette("husl", len(levels))
            else:
                colors = sns.color_palette(palette, len(levels))
            handles = [
                plt.Line2D([], [], marker="o", linestyle="", color=color, markeredgewidth=0)
                for color in colors
            ]
            legends_info.append((title, handles, [str(lvl) for lvl in levels]))

    # --- add niche outlines if requested ---
    if cluster_col in df.columns and palette_niches is not None:
        tri = _build_triangulation(
            df_view, x_col='X_coord', y_col='Y_coord',
            min_circle_ratio=min_circle_ratio
        )
        _overlay_niches_tricontour(
            ax, df_view, tri,
            niche_col=cluster_col,
            palette=palette_niches,
            highlight_niches=highlight_niches,
            line_width=outline_width,
            halo_width=halo_width,
            outline_alpha=outline_alpha,
            zorder_base=20,
        )

    # --- cleanup aesthetics ---
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    # Image-style orientation: smaller y at the top.
    ax.invert_yaxis()
    ax.set_aspect("equal")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_title(title, fontsize=16, fontweight="normal", pad=6)

    for spine in ["top", "right", "left", "bottom"]:
        ax.spines[spine].set_visible(False)
    ax.grid(False)


In [None]:
import tifffile
import numpy as np
from pathlib import Path
# Per-sample paths to the H&E tissue images (.btf), opened later with tifffile.
source_image_paths = {'P1CRC': '../../../Broad_SpatialFoundation/VisiumHD-CRC/P1CRC/Visium_HD_Human_Colon_Cancer_P1_tissue_image.btf',
                     'P2CRC': '../../../Broad_SpatialFoundation/VisiumHD-CRC/P2CRC/Visium_HD_Human_Colon_Cancer_P2_tissue_image.btf',
                     'P5CRC': '../../../Broad_SpatialFoundation/VisiumHD-CRC/P5CRC/Visium_HD_Human_Colon_Cancer_P5_tissue_image.btf',
                     'P3NAT': '../../../Broad_SpatialFoundation/VisiumHD-CRC/P3NAT/Visium_HD_Human_Colon_Normal_P3_tissue_image.btf',
                     'P5NAT': '../../../Broad_SpatialFoundation/VisiumHD-CRC/P5NAT/Visium_HD_Human_Colon_Normal_P5_tissue_image.btf',}


In [None]:
# Display the sample IDs (defined earlier in the notebook).
sample_list

In [None]:
# Independent per-sample AnnData copies, selected by obs['sample_id'].
adatas = {spl: adata[adata.obs['sample_id']==spl].copy() for spl in sample_list}

In [None]:
# === Parameters ===
# Crop window in the same coordinate space as obs['X_coord'] / obs['Y_coord'].
xlim = (0, 25000)
ylim = (10000, 38000)
# NOTE(review): point_size is not passed to plot_annotation_pathway below, which uses its
# own default (2); the global is only read by plot_annotation_wo_HE. Confirm intent.
point_size = 3
figsize = (13, 25)
save_prefix = "../../../SpatialFusion/results/figures_Fig5/P1CRC_pathway_activity"

# === Plot settings ===
# (obs column, panel title, colormap) for each pathway panel.
plot_configs = [
    ("Androgen", "Androgen Activity", "vlag"),
    ("EGFR", "EGFR Activity", "vlag"),
    ("Estrogen", "Estrogen Activity", "vlag"), 
    ("JAK-STAT", "JAK-STAT Activity", "vlag"),
    ("MAPK", "MAPK Activity", "vlag"),
    ("NFkB", "NFkB Activity", "vlag"),
    ("PI3K", "PI3K Activity", "vlag"),
    ("TGFb", "TGFb Activity", "vlag"),
    ("TNFa", "TNFa Activity", "vlag"),
    ("VEGF", "VEGF Activity", "vlag"),
 
]

# === Create subplots ===
# 5 x 2 grid: one panel per entry of plot_configs.
fig, axes = plt.subplots(5, 2, figsize=figsize, sharex=True, sharey=True)
axes = axes.flatten()

# NOTE(review): these lists are reset but never passed to plot_annotation_pathway here,
# so they stay empty in this cell.
legends_info = []
colorbar_info = []  # would collect (scatter, title) tuples if passed as colorbar_info=

# === Generate plots ===
# vmin/vmax are left at the helper's defaults.
for ax, (col, title, pal) in zip(axes, plot_configs):
    plot_annotation_pathway(adatas['P1CRC'], ax, col, title, pal, xlim=xlim,
    ylim=ylim,)

plt.tight_layout()
plt.show()

fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"✅ Saved panel: {save_prefix}_panel.svg")

In [None]:
# === Parameters ===
# Crop window in the same coordinate space as obs['X_coord'] / obs['Y_coord'].
xlim = (40000, 65000)
ylim = (0, 25000)
# NOTE(review): point_size is not passed to plot_annotation_pathway below, which uses its
# own default (2); the global is only read by plot_annotation_wo_HE. Confirm intent.
point_size = 3
figsize = (13, 25)
save_prefix = "../../../SpatialFusion/results/figures_Fig5/P2CRC_pathway_activity"

# === Plot settings ===
# (obs column, panel title, colormap) for each pathway panel.
plot_configs = [
    ("Androgen", "Androgen Activity", "vlag"),
    ("EGFR", "EGFR Activity", "vlag"),
    ("Estrogen", "Estrogen Activity", "vlag"), 
    ("JAK-STAT", "JAK-STAT Activity", "vlag"),
    ("MAPK", "MAPK Activity", "vlag"),
    ("NFkB", "NFkB Activity", "vlag"),
    ("PI3K", "PI3K Activity", "vlag"),
    ("TGFb", "TGFb Activity", "vlag"),
    ("TNFa", "TNFa Activity", "vlag"),
    ("VEGF", "VEGF Activity", "vlag"),
 
]

# === Create subplots ===
# 5 x 2 grid: one panel per entry of plot_configs.
fig, axes = plt.subplots(5, 2, figsize=figsize, sharex=True, sharey=True)
axes = axes.flatten()

# NOTE(review): these lists are reset but never passed to plot_annotation_pathway here,
# so they stay empty in this cell.
legends_info = []
colorbar_info = []  # would collect (scatter, title) tuples if passed as colorbar_info=

# === Generate plots ===
# vmin/vmax are left at the helper's defaults.
for ax, (col, title, pal) in zip(axes, plot_configs):
    plot_annotation_pathway(adatas['P2CRC'], ax, col, title, pal, xlim=xlim,
    ylim=ylim,)

plt.tight_layout()
plt.show()

fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"✅ Saved panel: {save_prefix}_panel.svg")

In [None]:
# === Parameters ===
# Crop window in the same coordinate space as obs['X_coord'] / obs['Y_coord'].
xlim = (45000, 65000)
ylim = (40000, 65000)
# NOTE(review): point_size is not passed to plot_annotation_pathway below, which uses its
# own default (2); the global is only read by plot_annotation_wo_HE. Confirm intent.
point_size = 3
figsize = (13, 25)
save_prefix = "../../../SpatialFusion/results/figures_Fig5/P5CRC_pathway_activity"

# === Plot settings ===
# (obs column, panel title, colormap) for each pathway panel.
plot_configs = [
    ("Androgen", "Androgen Activity", "vlag"),
    ("EGFR", "EGFR Activity", "vlag"),
    ("Estrogen", "Estrogen Activity", "vlag"), 
    ("JAK-STAT", "JAK-STAT Activity", "vlag"),
    ("MAPK", "MAPK Activity", "vlag"),
    ("NFkB", "NFkB Activity", "vlag"),
    ("PI3K", "PI3K Activity", "vlag"),
    ("TGFb", "TGFb Activity", "vlag"),
    ("TNFa", "TNFa Activity", "vlag"),
    ("VEGF", "VEGF Activity", "vlag"),
 
]

# === Create subplots ===
# 5 x 2 grid: one panel per entry of plot_configs.
fig, axes = plt.subplots(5, 2, figsize=figsize, sharex=True, sharey=True)
axes = axes.flatten()

# NOTE(review): these lists are reset but never passed to plot_annotation_pathway here,
# so they stay empty in this cell.
legends_info = []
colorbar_info = []  # would collect (scatter, title) tuples if passed as colorbar_info=

# === Generate plots ===
# vmin/vmax are left at the helper's defaults.
for ax, (col, title, pal) in zip(axes, plot_configs):
    plot_annotation_pathway(adatas['P5CRC'], ax, col, title, pal, xlim=xlim,
    ylim=ylim,)

plt.tight_layout()
plt.show()

fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"✅ Saved panel: {save_prefix}_panel.svg")

In [None]:
# === Parameters ===
# Crop window in the same coordinate space as obs['X_coord'] / obs['Y_coord'].
xlim = (0, 22000)
ylim = (10000, 35000)
# NOTE(review): point_size is not passed to plot_annotation_pathway below, which uses its
# own default (2); the global is only read by plot_annotation_wo_HE. Confirm intent.
point_size = 3
figsize = (13, 25)
save_prefix = "../../../SpatialFusion/results/figures_Fig5/P3NAT_pathway_activity"

# === Plot settings ===
# (obs column, panel title, colormap) for each pathway panel.
plot_configs = [
    ("Androgen", "Androgen Activity", "vlag"),
    ("EGFR", "EGFR Activity", "vlag"),
    ("Estrogen", "Estrogen Activity", "vlag"), 
    ("JAK-STAT", "JAK-STAT Activity", "vlag"),
    ("MAPK", "MAPK Activity", "vlag"),
    ("NFkB", "NFkB Activity", "vlag"),
    ("PI3K", "PI3K Activity", "vlag"),
    ("TGFb", "TGFb Activity", "vlag"),
    ("TNFa", "TNFa Activity", "vlag"),
    ("VEGF", "VEGF Activity", "vlag"),
 
]

# === Create subplots ===
# 5 x 2 grid: one panel per entry of plot_configs.
fig, axes = plt.subplots(5, 2, figsize=figsize, sharex=True, sharey=True)
axes = axes.flatten()

# NOTE(review): these lists are reset but never passed to plot_annotation_pathway here,
# so they stay empty in this cell.
legends_info = []
colorbar_info = []  # would collect (scatter, title) tuples if passed as colorbar_info=

# === Generate plots ===
# vmin/vmax are left at the helper's defaults.
for ax, (col, title, pal) in zip(axes, plot_configs):
    plot_annotation_pathway(adatas['P3NAT'], ax, col, title, pal, xlim=xlim,
    ylim=ylim,)

plt.tight_layout()
plt.show()

fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"✅ Saved panel: {save_prefix}_panel.svg")

In [None]:
# === Parameters ===
# Crop window in the same coordinate space as obs['X_coord'] / obs['Y_coord'].
xlim = (40000, 68000)
ylim = (25000, 53000)
# NOTE(review): point_size is not passed to plot_annotation_pathway below, which uses its
# own default (2); the global is only read by plot_annotation_wo_HE. Confirm intent.
point_size = 3
figsize = (13, 25)
save_prefix = "../../../SpatialFusion/results/figures_Fig5/P5NAT_pathway_activity"

# === Plot settings ===
# (obs column, panel title, colormap) for each pathway panel.
plot_configs = [
    ("Androgen", "Androgen Activity", "vlag"),
    ("EGFR", "EGFR Activity", "vlag"),
    ("Estrogen", "Estrogen Activity", "vlag"), 
    ("JAK-STAT", "JAK-STAT Activity", "vlag"),
    ("MAPK", "MAPK Activity", "vlag"),
    ("NFkB", "NFkB Activity", "vlag"),
    ("PI3K", "PI3K Activity", "vlag"),
    ("TGFb", "TGFb Activity", "vlag"),
    ("TNFa", "TNFa Activity", "vlag"),
    ("VEGF", "VEGF Activity", "vlag"),
 
]

# === Create subplots ===
# 5 x 2 grid: one panel per entry of plot_configs.
fig, axes = plt.subplots(5, 2, figsize=figsize, sharex=True, sharey=True)
axes = axes.flatten()

# NOTE(review): these lists are reset but never passed to plot_annotation_pathway here,
# so they stay empty in this cell.
legends_info = []
colorbar_info = []  # would collect (scatter, title) tuples if passed as colorbar_info=

# === Generate plots ===
# vmin/vmax are left at the helper's defaults.
for ax, (col, title, pal) in zip(axes, plot_configs):
    plot_annotation_pathway(adatas['P5NAT'], ax, col, title, pal, xlim=xlim,
    ylim=ylim,)

plt.tight_layout()
plt.show()

fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"✅ Saved panel: {save_prefix}_panel.svg")

# Visualize pathology

In [None]:
sample_name = 'P1CRC'

print('Opening WSI')
# Load series 0 of the sample's .btf image into a numpy array.
# NOTE(review): this reads the whole image into memory; it may be very large — confirm budget.
with tifffile.TiffFile(source_image_paths[sample_name]) as tif:
    wsi = tif.series[0].asarray()

In [None]:
# --- Region of interest in H&E pixel space ---
xlim = (5000, 10000)
ylim = (22000, 27000)
# NOTE(review): x0/x1/y0/y1 are not used in this cell.
x0, x1 = xlim
y0, y1 = ylim

# --- Set up figure ---
# (plot_spatial_with_HE_outline applies the same seaborn style and context itself.)
sns.set_style("white")
sns.set_context("talk")

# Outline leiden_joint niches 1 and 2 over the H&E crop and save it as SVG.
plot_spatial_with_HE_outline(
    wsi,
    adatas[sample_name],
    sample_name,
    palette_dict_1['leiden_joint'],
    xlim,
    ylim,
    x_col="X_coord",
    y_col="Y_coord",
    cluster_col="leiden_joint",
    he_alpha=0.9,  
    outline_alpha=0.9,
    outline_width=3.0,
    highlight_niches=['1','2'],
    figsize=(7, 7),
    dpi=300,
    savefig=f'../../../SpatialFusion/results/figures_Fig5/{sample_name}_outline_zoom_niche12.svg',
    show=True,
)


In [None]:
# === Parameters ===
# Same crop window as the H&E outline zoom above.
xlim = (5000, 10000)
ylim = (22000, 27000)
# NOTE(review): point_size is not passed to plot_annotation_pathway (default 2 is used).
point_size = 3
figsize = (5, 5)
save_prefix = f"../../../SpatialFusion/results/figures_Fig5/{sample_name}_androgen_activity_zoom"

# === Plot settings ===
plot_configs = [
    ("Androgen", "Androgen Activity", "vlag"),
]

# === Create subplots ===
fig, axes = plt.subplots(1, 1, figsize=figsize, sharex=True, sharey=True)
# A single subplot returns one Axes; wrap it so the zip loop below works unchanged.
axes = [axes]

# NOTE(review): reset but never passed to the plotting helper in this cell.
legends_info = []
colorbar_info = []  # would collect (scatter, title) tuples if passed as colorbar_info=

# === Generate plots ===
# Scores clipped to [-1, 3]; niches 1 and 2 outlined on top.
for ax, (col, title, pal) in zip(axes, plot_configs):
    plot_annotation_pathway(adatas[sample_name], ax, col, title, pal, xlim=xlim,
    ylim=ylim, vmin=-1, vmax=3, cluster_col='leiden_joint',
                            palette_niches=palette_dict_1['leiden_joint'], highlight_niches=['1','2'])

plt.tight_layout()
plt.show()

fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"✅ Saved panel: {save_prefix}_panel.svg")

In [None]:
# === Parameters ===
# Same crop window as the H&E outline zoom above.
xlim = (5000, 10000)
ylim = (22000, 27000)
# NOTE(review): point_size is not passed to plot_annotation_pathway (default 2 is used).
point_size = 3
figsize = (5, 5)
save_prefix = f"../../../SpatialFusion/results/figures_Fig5/{sample_name}_JAK-STAT_activity_zoom"

# === Plot settings ===
plot_configs = [
    ("JAK-STAT", "JAK-STAT Activity", "vlag"),
]

# === Create subplots ===
fig, axes = plt.subplots(1, 1, figsize=figsize, sharex=True, sharey=True)
# A single subplot returns one Axes; wrap it so the zip loop below works unchanged.
axes = [axes]

# NOTE(review): reset but never passed to the plotting helper in this cell.
legends_info = []
colorbar_info = []  # would collect (scatter, title) tuples if passed as colorbar_info=

# === Generate plots ===
# Scores clipped to [-1, 3]; niches 1 and 2 outlined on top.
for ax, (col, title, pal) in zip(axes, plot_configs):
    plot_annotation_pathway(adatas[sample_name], ax, col, title, pal, xlim=xlim,
    ylim=ylim, vmin=-1, vmax=3, cluster_col='leiden_joint',
                            palette_niches=palette_dict_1['leiden_joint'], highlight_niches=['1','2'])

plt.tight_layout()
plt.show()

fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"✅ Saved panel: {save_prefix}_panel.svg")

In [None]:
sample_name = 'P2CRC'

# Load series 0 of the sample's .btf image into a numpy array (replaces the previous wsi).
# NOTE(review): this reads the whole image into memory; it may be very large — confirm budget.
with tifffile.TiffFile(source_image_paths[sample_name]) as tif:
    wsi = tif.series[0].asarray()

# Display the image shape to sanity-check the crop windows below.
wsi.shape

In [None]:
# --- Region of interest in H&E pixel space ---
xlim = (53000, 58000)
ylim = (8000, 13000)
# NOTE(review): x0/x1/y0/y1 are not used in this cell.
x0, x1 = xlim
y0, y1 = ylim

# --- Set up figure ---
# (plot_spatial_with_HE_outline applies the same seaborn style and context itself.)
sns.set_style("white")
sns.set_context("talk")

# Outline leiden_joint niches 1 and 2 over the H&E crop and save it as SVG.
plot_spatial_with_HE_outline(
    wsi,
    adatas[sample_name],
    sample_name,
    palette_dict_1['leiden_joint'],
    xlim,
    ylim,
    x_col="X_coord",
    y_col="Y_coord",
    cluster_col="leiden_joint",
    he_alpha=0.9,  
    outline_alpha=0.9,
    outline_width=3.0,
    highlight_niches=['1','2'],
    figsize=(7, 7),
    dpi=300,
    savefig=f'../../../SpatialFusion/results/figures_Fig5/{sample_name}_outline_zoom_niche12.svg',
    show=True,
)


In [None]:
# === Parameters ===
# Same crop window as the H&E outline zoom above.
xlim = (53000, 58000)
ylim = (8000, 13000)
# NOTE(review): point_size is not passed to plot_annotation_pathway (default 2 is used).
point_size = 3
figsize = (5, 5)
save_prefix = f"../../../SpatialFusion/results/figures_Fig5/{sample_name}_androgen_activity_zoom"

# === Plot settings ===
plot_configs = [
    ("Androgen", "Androgen Activity", "vlag"),
 
]

# === Create subplots ===
fig, axes = plt.subplots(1, 1, figsize=figsize, sharex=True, sharey=True)
# A single subplot returns one Axes; wrap it so the zip loop below works unchanged.
axes = [axes]

# NOTE(review): reset but never passed to the plotting helper in this cell.
legends_info = []
colorbar_info = []  # would collect (scatter, title) tuples if passed as colorbar_info=

# === Generate plots ===
# Scores clipped to [0, 3] here (the P1CRC zoom used [-1, 3]); niches 1 and 2 outlined.
for ax, (col, title, pal) in zip(axes, plot_configs):
    plot_annotation_pathway(adatas[sample_name], ax, col, title, pal, xlim=xlim,
    ylim=ylim, vmin=0, vmax=3, cluster_col='leiden_joint',
                            palette_niches=palette_dict_1['leiden_joint'], highlight_niches=['1','2'])

plt.tight_layout()
plt.show()

fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"✅ Saved panel: {save_prefix}_panel.svg")

In [None]:
# === Parameters ===
# Zoom window (same as the H&E outline cell for this sample) and output prefix.
xlim = (53000, 58000)
ylim = (8000, 13000)
# NOTE(review): point_size is not passed to plot_annotation_pathway below; presumably
# unused here or read as a global by the helper — confirm.
point_size = 3
figsize = (5, 5)
save_prefix = f"../../../SpatialFusion/results/figures_Fig5/{sample_name}_JAK-STAT_activity_zoom"

# === Plot settings ===
# One panel per entry: (column to plot, panel title, palette name).
plot_configs = [
    ("JAK-STAT", "JAK-STAT Activity", "vlag"),
 
]

# === Create subplots ===
fig, axes = plt.subplots(1, 1, figsize=figsize, sharex=True, sharey=True)
# A 1x1 plt.subplots call returns a bare Axes; wrap it in a list for the zip() below.
axes = [axes]

# NOTE(review): legends_info / colorbar_info are not read in this cell.
legends_info = []
colorbar_info = []  # store (ax, column, cmap) for continuous plots

# === Generate plots ===
for ax, (col, title, pal) in zip(axes, plot_configs):
    # Helper defined elsewhere. vmin/vmax fix the color range; highlight_niches
    # presumably emphasizes leiden_joint niches '1' and '2'.
    plot_annotation_pathway(adatas[sample_name], ax, col, title, pal, xlim=xlim,
    ylim=ylim, vmin=0, vmax=3, cluster_col='leiden_joint',
                            palette_niches=palette_dict_1['leiden_joint'], highlight_niches=['1','2'])

plt.tight_layout()
plt.show()

# Save through the `fig` handle so the file does not depend on pyplot's current figure after show().
fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"✅ Saved panel: {save_prefix}_panel.svg")

In [None]:
# Switch the active sample; the following cells read `sample_name`.
sample_name = 'P5CRC'

# Read the first series of this sample's TIFF fully into memory as an array.
with tifffile.TiffFile(source_image_paths[sample_name]) as tif:
    wsi = tif.series[0].asarray()

# Show the image shape as the cell output.
wsi.shape

In [None]:
# --- Region of interest in H&E pixel space ---
xlim = (60000, 65000)
ylim = (55000, 60000)
# NOTE(review): x0/x1/y0/y1 are not used in this cell; presumably kept for later use — confirm.
x0, x1 = xlim
y0, y1 = ylim

# --- Set up figure ---
# Session-wide seaborn style/context (affects all later plots).
sns.set_style("white")
sns.set_context("talk")

# Helper defined elsewhere. Judging by the argument names, it draws the `wsi` crop
# inside xlim/ylim with leiden_joint niche outlines (niches '1' and '2' highlighted)
# and writes the SVG given by `savefig`.
plot_spatial_with_HE_outline(
    wsi,
    adatas[sample_name],
    sample_name,
    palette_dict_1['leiden_joint'],
    xlim,
    ylim,
    x_col="X_coord",
    y_col="Y_coord",
    cluster_col="leiden_joint",
    he_alpha=0.9,  
    outline_alpha=0.9,
    outline_width=3.0,
    highlight_niches=['1','2'],
    figsize=(7, 7),
    dpi=300,
    savefig=f'../../../SpatialFusion/results/figures_Fig5/{sample_name}_outline_zoom_niche12.svg',
    show=True,
)


In [None]:
# === Parameters ===
# Zoom window (same as the H&E outline cell for this sample) and output prefix.
xlim = (60000, 65000)
ylim = (55000, 60000)
# NOTE(review): point_size is not passed to plot_annotation_pathway below; presumably
# unused here or read as a global by the helper — confirm.
point_size = 3
figsize = (5, 5)
save_prefix = f"../../../SpatialFusion/results/figures_Fig5/{sample_name}_androgen_activity_zoom"

# === Plot settings ===
# One panel per entry: (column to plot, panel title, palette name).
plot_configs = [
    ("Androgen", "Androgen Activity", "vlag"),
 
]

# === Create subplots ===
fig, axes = plt.subplots(1, 1, figsize=figsize, sharex=True, sharey=True)
# A 1x1 plt.subplots call returns a bare Axes; wrap it in a list for the zip() below.
axes = [axes]

# NOTE(review): legends_info / colorbar_info are not read in this cell.
legends_info = []
colorbar_info = []  # store (ax, column, cmap) for continuous plots

# === Generate plots ===
for ax, (col, title, pal) in zip(axes, plot_configs):
    # Helper defined elsewhere. vmin/vmax fix the color range; highlight_niches
    # presumably emphasizes leiden_joint niches '1' and '2'.
    plot_annotation_pathway(adatas[sample_name], ax, col, title, pal, xlim=xlim,
    ylim=ylim, vmin=0, vmax=3, cluster_col='leiden_joint',
                            palette_niches=palette_dict_1['leiden_joint'], highlight_niches=['1','2'])

plt.tight_layout()
plt.show()

# Save through the `fig` handle so the file does not depend on pyplot's current figure after show().
fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"✅ Saved panel: {save_prefix}_panel.svg")

In [None]:
# === Parameters ===
# Zoom window (same as the H&E outline cell for this sample) and output prefix.
xlim = (60000, 65000)
ylim = (55000, 60000)
# NOTE(review): point_size is not passed to plot_annotation_pathway below; presumably
# unused here or read as a global by the helper — confirm.
point_size = 3
figsize = (5, 5)
save_prefix = f"../../../SpatialFusion/results/figures_Fig5/{sample_name}_JAK-STAT_activity_zoom"

# === Plot settings ===
# One panel per entry: (column to plot, panel title, palette name).
plot_configs = [
    ("JAK-STAT", "JAK-STAT Activity", "vlag"),
 
]

# === Create subplots ===
fig, axes = plt.subplots(1, 1, figsize=figsize, sharex=True, sharey=True)
# A 1x1 plt.subplots call returns a bare Axes; wrap it in a list for the zip() below.
axes = [axes]

# NOTE(review): legends_info / colorbar_info are not read in this cell.
legends_info = []
colorbar_info = []  # store (ax, column, cmap) for continuous plots

# === Generate plots ===
for ax, (col, title, pal) in zip(axes, plot_configs):
    # Helper defined elsewhere. vmin/vmax fix the color range; highlight_niches
    # presumably emphasizes leiden_joint niches '1' and '2'.
    plot_annotation_pathway(adatas[sample_name], ax, col, title, pal, xlim=xlim,
    ylim=ylim, vmin=0, vmax=3, cluster_col='leiden_joint',
                            palette_niches=palette_dict_1['leiden_joint'], highlight_niches=['1','2'])

plt.tight_layout()
plt.show()

# Save through the `fig` handle so the file does not depend on pyplot's current figure after show().
fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"✅ Saved panel: {save_prefix}_panel.svg")

In [None]:
# Switch the active sample; the following cells read `sample_name`.
sample_name = 'P3NAT'

# Read the first series of this sample's TIFF fully into memory as an array.
with tifffile.TiffFile(source_image_paths[sample_name]) as tif:
    wsi = tif.series[0].asarray()

# Show the image shape as the cell output.
wsi.shape

In [None]:
# --- Region of interest in H&E pixel space ---
xlim = (0, 5000)
ylim = (23000, 28000)
# NOTE(review): x0/x1/y0/y1 are not used in this cell; presumably kept for later use — confirm.
x0, x1 = xlim
y0, y1 = ylim

# --- Set up figure ---
# Session-wide seaborn style/context (affects all later plots).
sns.set_style("white")
sns.set_context("talk")

# Helper defined elsewhere. Judging by the argument names, it draws the `wsi` crop
# inside xlim/ylim with leiden_joint niche outlines (niches '1' and '2' highlighted)
# and writes the SVG given by `savefig`.
plot_spatial_with_HE_outline(
    wsi,
    adatas[sample_name],
    sample_name,
    palette_dict_1['leiden_joint'],
    xlim,
    ylim,
    x_col="X_coord",
    y_col="Y_coord",
    cluster_col="leiden_joint",
    he_alpha=0.9,  
    outline_alpha=0.9,
    outline_width=3.0,
    highlight_niches=['1','2'],
    figsize=(7, 7),
    dpi=300,
    savefig=f'../../../SpatialFusion/results/figures_Fig5/{sample_name}_outline_zoom_niche12.svg',
    show=True,
)


In [None]:
# === Parameters ===
# Zoom window (same as the H&E outline cell for this sample) and output prefix.
xlim = (0, 5000)
ylim = (23000, 28000)
# NOTE(review): point_size is not passed to plot_annotation_pathway below; presumably
# unused here or read as a global by the helper — confirm.
point_size = 3
figsize = (5, 5)
save_prefix = f"../../../SpatialFusion/results/figures_Fig5/{sample_name}_androgen_activity_zoom"

# === Plot settings ===
# One panel per entry: (column to plot, panel title, palette name).
plot_configs = [
    ("Androgen", "Androgen Activity", "vlag"),
 
]

# === Create subplots ===
fig, axes = plt.subplots(1, 1, figsize=figsize, sharex=True, sharey=True)
# A 1x1 plt.subplots call returns a bare Axes; wrap it in a list for the zip() below.
axes = [axes]

# NOTE(review): legends_info / colorbar_info are not read in this cell.
legends_info = []
colorbar_info = []  # store (ax, column, cmap) for continuous plots

# === Generate plots ===
for ax, (col, title, pal) in zip(axes, plot_configs):
    # Helper defined elsewhere. vmin/vmax fix the color range; highlight_niches
    # presumably emphasizes leiden_joint niches '1' and '2'.
    plot_annotation_pathway(adatas[sample_name], ax, col, title, pal, xlim=xlim,
    ylim=ylim, vmin=0, vmax=3, cluster_col='leiden_joint',
                            palette_niches=palette_dict_1['leiden_joint'], highlight_niches=['1','2'])

plt.tight_layout()
plt.show()

# Save through the `fig` handle so the file does not depend on pyplot's current figure after show().
fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"✅ Saved panel: {save_prefix}_panel.svg")

In [None]:
# === Parameters ===
# Zoom window (same as the H&E outline cell for this sample) and output prefix.
xlim = (0, 5000)
ylim = (23000, 28000)
# NOTE(review): point_size is not passed to plot_annotation_pathway below; presumably
# unused here or read as a global by the helper — confirm.
point_size = 3
figsize = (5, 5)
save_prefix = f"../../../SpatialFusion/results/figures_Fig5/{sample_name}_JAK-STAT_activity_zoom"

# === Plot settings ===
# One panel per entry: (column to plot, panel title, palette name).
plot_configs = [
    ("JAK-STAT", "JAK-STAT Activity", "vlag"),
 
]

# === Create subplots ===
fig, axes = plt.subplots(1, 1, figsize=figsize, sharex=True, sharey=True)
# A 1x1 plt.subplots call returns a bare Axes; wrap it in a list for the zip() below.
axes = [axes]

# NOTE(review): legends_info / colorbar_info are not read in this cell.
legends_info = []
colorbar_info = []  # store (ax, column, cmap) for continuous plots

# === Generate plots ===
for ax, (col, title, pal) in zip(axes, plot_configs):
    # Helper defined elsewhere. vmin/vmax fix the color range; highlight_niches
    # presumably emphasizes leiden_joint niches '1' and '2'.
    plot_annotation_pathway(adatas[sample_name], ax, col, title, pal, xlim=xlim,
    ylim=ylim, vmin=0, vmax=3, cluster_col='leiden_joint',
                            palette_niches=palette_dict_1['leiden_joint'], highlight_niches=['1','2'])

plt.tight_layout()
plt.show()

# Save through the `fig` handle so the file does not depend on pyplot's current figure after show().
fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"✅ Saved panel: {save_prefix}_panel.svg")

In [None]:
# Switch the active sample; the following cells read `sample_name`.
sample_name = 'P5NAT'

# Read the first series of this sample's TIFF fully into memory as an array.
with tifffile.TiffFile(source_image_paths[sample_name]) as tif:
    wsi = tif.series[0].asarray()

# Show the image shape as the cell output.
wsi.shape

In [None]:
# --- Region of interest in H&E pixel space ---
xlim = (48000, 53000)
ylim = (33000, 38000)
# NOTE(review): x0/x1/y0/y1 are not used in this cell; presumably kept for later use — confirm.
x0, x1 = xlim
y0, y1 = ylim

# --- Set up figure ---
# Session-wide seaborn style/context (affects all later plots).
sns.set_style("white")
sns.set_context("talk")

# Helper defined elsewhere. Judging by the argument names, it draws the `wsi` crop
# inside xlim/ylim with leiden_joint niche outlines (niches '1' and '2' highlighted)
# and writes the SVG given by `savefig`.
plot_spatial_with_HE_outline(
    wsi,
    adatas[sample_name],
    sample_name,
    palette_dict_1['leiden_joint'],
    xlim,
    ylim,
    x_col="X_coord",
    y_col="Y_coord",
    cluster_col="leiden_joint",
    he_alpha=0.9,  
    outline_alpha=0.9,
    outline_width=3.0,
    highlight_niches=['1','2'],
    figsize=(7, 7),
    dpi=300,
    savefig=f'../../../SpatialFusion/results/figures_Fig5/{sample_name}_outline_zoom_niche12.svg',
    show=True,
)


In [None]:
# === Parameters ===
# Zoom window (same as the H&E outline cell for this sample) and output prefix.
xlim = (48000, 53000)
ylim = (33000, 38000)
# NOTE(review): point_size is not passed to plot_annotation_pathway below; presumably
# unused here or read as a global by the helper — confirm.
point_size = 3
figsize = (5, 5)
save_prefix = f"../../../SpatialFusion/results/figures_Fig5/{sample_name}_androgen_activity_zoom"

# === Plot settings ===
# One panel per entry: (column to plot, panel title, palette name).
plot_configs = [
    ("Androgen", "Androgen Activity", "vlag"),
 
]

# === Create subplots ===
fig, axes = plt.subplots(1, 1, figsize=figsize, sharex=True, sharey=True)
# A 1x1 plt.subplots call returns a bare Axes; wrap it in a list for the zip() below.
axes = [axes]

# NOTE(review): legends_info / colorbar_info are not read in this cell.
legends_info = []
colorbar_info = []  # store (ax, column, cmap) for continuous plots

# === Generate plots ===
for ax, (col, title, pal) in zip(axes, plot_configs):
    # Helper defined elsewhere. vmin/vmax fix the color range; highlight_niches
    # presumably emphasizes leiden_joint niches '1' and '2'.
    plot_annotation_pathway(adatas[sample_name], ax, col, title, pal, xlim=xlim,
    ylim=ylim, vmin=0, vmax=3, cluster_col='leiden_joint',
                            palette_niches=palette_dict_1['leiden_joint'], highlight_niches=['1','2'])

plt.tight_layout()
plt.show()

# Save through the `fig` handle so the file does not depend on pyplot's current figure after show().
fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"✅ Saved panel: {save_prefix}_panel.svg")

In [None]:
# === Parameters ===
# Zoom window (same as the H&E outline cell for this sample) and output prefix.
xlim = (48000, 53000)
ylim = (33000, 38000)
# NOTE(review): point_size is not passed to plot_annotation_pathway below; presumably
# unused here or read as a global by the helper — confirm.
point_size = 3
figsize = (5, 5)
save_prefix = f"../../../SpatialFusion/results/figures_Fig5/{sample_name}_JAK-STAT_activity_zoom"

# === Plot settings ===
# One panel per entry: (column to plot, panel title, palette name).
plot_configs = [
    ("JAK-STAT", "JAK-STAT Activity", "vlag"),
 
]

# === Create subplots ===
fig, axes = plt.subplots(1, 1, figsize=figsize, sharex=True, sharey=True)
# A 1x1 plt.subplots call returns a bare Axes; wrap it in a list for the zip() below.
axes = [axes]

# NOTE(review): legends_info / colorbar_info are not read in this cell.
legends_info = []
colorbar_info = []  # store (ax, column, cmap) for continuous plots

# === Generate plots ===
for ax, (col, title, pal) in zip(axes, plot_configs):
    # Helper defined elsewhere. vmin/vmax fix the color range; highlight_niches
    # presumably emphasizes leiden_joint niches '1' and '2'.
    plot_annotation_pathway(adatas[sample_name], ax, col, title, pal, xlim=xlim,
    ylim=ylim, vmin=0, vmax=3, cluster_col='leiden_joint',
                            palette_niches=palette_dict_1['leiden_joint'], highlight_niches=['1','2'])

plt.tight_layout()
plt.show()

# Save through the `fig` handle so the file does not depend on pyplot's current figure after show().
fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"✅ Saved panel: {save_prefix}_panel.svg")

# Niche 1 and 2

## Functions

In [None]:
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np
import os

def plot_umap_nature_style(
    adata,
    color_vars=("leiden", "condition", "refined_cellsubtypes", "sample_id"),
    palettes=None,
    ncols=2,
    point_size=8,
    cmap="viridis",
    frameon=False,
    figsize=(10, 10),
    title_fontsize=11,
    label_fontsize=9,
    shuffle=True,
    random_state=42,
    save_path=None,         # <--- single output SVG/PDF path
    rasterized=True,        # <--- rasterize only scatter points
):
    """
    Draw a Nature Genetics–style multi-panel UMAP figure.

    Combines one UMAP panel per entry of ``color_vars`` into a single grid and,
    optionally, rasterizes only the scatter points while keeping axes and labels
    as vector graphics.

    Side effects: calls ``sc.set_figure_params`` and updates ``plt.rcParams``,
    which changes plotting defaults for the rest of the session.

    Parameters
    ----------
    adata : AnnData
        The annotated data matrix; ``sc.pl.umap`` reads its stored UMAP embedding.
    color_vars : sequence of str
        Keys to color by; each one becomes a subplot. Any iterable is accepted.
    palettes : dict or None
        Optional mapping ``color_var -> palette`` (dict or list) forwarded to
        ``sc.pl.umap``. Panels without a palette use ``cmap`` instead.
    ncols : int
        Number of columns in the subplot grid.
    point_size : float
        Marker size forwarded to ``sc.pl.umap(size=...)``.
    cmap : str
        Colormap for panels that have no explicit palette.
    frameon : bool
        Forwarded to ``sc.pl.umap``.
    figsize : tuple
        Size of the whole figure (also passed to ``sc.set_figure_params``).
    title_fontsize, label_fontsize : float
        Axes title / label font sizes, set through ``plt.rcParams``.
    shuffle : bool
        Plot cells in random order so no group is systematically drawn on top.
    random_state : int
        Seed for the shuffling permutation.
    save_path : str or None
        Output path for the combined figure (e.g. ``'umap_panels.svg'``).
        Missing parent directories are created; a bare filename saves to the
        current working directory.
    rasterized : bool
        If True, only the scatter collections are rasterized to reduce file size.

    Returns
    -------
    matplotlib.figure.Figure
        The combined figure.

    Raises
    ------
    ValueError
        If ``color_vars`` is empty.
    """
    # Accept any iterable; the default is a tuple to avoid a shared mutable default.
    color_vars = list(color_vars)
    if not color_vars:
        raise ValueError("color_vars must contain at least one key to plot.")

    # -------------------------------
    # Global style configuration
    # -------------------------------
    sc.set_figure_params(
        dpi=200,
        dpi_save=300,
        fontsize=9,
        facecolor="white",
        figsize=figsize,
    )

    plt.rcParams.update({
        "axes.edgecolor": "black",
        "axes.linewidth": 0.6,
        "axes.spines.right": False,
        "axes.spines.top": False,
        "axes.titlesize": title_fontsize,
        "axes.labelsize": label_fontsize,
        "xtick.labelsize": 7,
        "ytick.labelsize": 7,
    })

    # Shuffle order of cells to avoid overlay bias
    if shuffle:
        rng = np.random.default_rng(random_state)
        idx = rng.permutation(adata.n_obs)
    else:
        idx = np.arange(adata.n_obs)
    adata_shuffled = adata[idx, :].copy()

    # -------------------------------
    # Layout setup
    # -------------------------------
    n_panels = len(color_vars)
    nrows = int(np.ceil(n_panels / ncols))
    # squeeze=False keeps a 2-D array even for a single row/column, so flatten always works.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    # -------------------------------
    # Plot each UMAP panel
    # -------------------------------
    for ax, c in zip(axes, color_vars):
        pal = palettes[c] if palettes and c in palettes else None

        # Draw into our own Axes so rasterization can be controlled per panel.
        sc.pl.umap(
            adata_shuffled,
            color=c,
            size=point_size,
            palette=pal,
            cmap=cmap if pal is None else None,
            frameon=frameon,
            ax=ax,
            show=False,
            title=c,
        )

        # Rasterize only the scatter artists (points)
        if rasterized:
            for coll in ax.collections:
                coll.set_rasterized(True)

        ax.set_xlabel("UMAP1")
        ax.set_ylabel("UMAP2")

    # Hide unused grid cells
    for ax in axes[n_panels:]:
        ax.set_visible(False)

    plt.tight_layout(w_pad=0.5, h_pad=0.7)

    # -------------------------------
    # Save figure
    # -------------------------------
    if save_path is not None:
        out_dir = os.path.dirname(save_path)
        # A bare filename has no directory part, and os.makedirs("") would raise.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight", transparent=True)
        print(f"✅ Saved combined figure: {save_path}")

    plt.show()
    return fig


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.path import Path
from matplotlib.patches import PathPatch, Rectangle
from typing import Optional, Tuple, Dict, Sequence
import os

def _ribbon_path(x0, x1, y0_bot, y0_top, y1_bot, y1_top, curvature=0.5, steps=40):
    """
    Build a closed band-shaped Path linking two vertical intervals.

    The band starts at ``x0`` covering ``[y0_bot, y0_top]`` and ends at ``x1``
    covering ``[y1_bot, y1_top]``. Its vertical transition mixes a linear ramp
    with the smoothstep curve ``3t^2 - 2t^3``; ``curvature`` is the smoothstep
    weight (0 gives straight edges, 1 full smoothstep).

    Returns
    -------
    matplotlib.path.Path
        ``steps`` vertices along the lower edge (left to right), ``steps`` along
        the upper edge (right to left), then a closing vertex.
    """
    t = np.linspace(0, 1, steps)
    xs = np.linspace(x0, x1, steps)

    # Interpolation weight: linear ramp blended with smoothstep.
    smooth = 3*t**2 - 2*t**3
    w = (1 - curvature) * t + curvature * smooth

    lower = (1 - w) * y0_bot + w * y1_bot
    upper = (1 - w) * y0_top + w * y1_top

    # Lower edge forward, upper edge backward -> one closed outline.
    outline = np.concatenate([
        np.column_stack([xs, lower]),
        np.column_stack([xs, upper])[::-1],
    ])
    closed = np.vstack([outline, outline[0]])

    n_edge = len(xs)
    codes = np.array([Path.MOVETO] + [Path.LINETO] * (2 * n_edge - 1) + [Path.CLOSEPOLY])
    return Path(closed, codes)

def alluvial_two_groups(
    data: pd.DataFrame,
    niche_col: str = "niche",
    subtype_col: str = "cell_subtype",
    weight_col: Optional[str] = None,
    normalize: bool = True,
    min_frac_to_label: float = 0.06,
    title: Optional[str] = None,
    outfile_prefix: Optional[str] = None,
    palette: Optional[Dict[str, str]] = None,
    subtype_order: Optional[Sequence[str]] = None,
    niche_order: Optional[Sequence[str]] = None,
    figsize=(5,3),
) -> Tuple[plt.Figure, plt.Axes, pd.DataFrame, pd.DataFrame]:
    """
    Draw a two-bar alluvial plot of subtype composition in exactly two niches.

    Each niche is drawn as a stacked bar of subtype fractions (heights sum to 1),
    and a ribbon connects the same subtype's segments on the two bars.

    Parameters
    ----------
    data : pandas.DataFrame
        One row per observation.
    niche_col : str
        Column holding the niche label; must contain exactly two distinct values.
        Labels are compared as strings.
    subtype_col : str
        Column holding the subtype label to stack.
    weight_col : str or None
        Optional column of per-row weights; every row counts 1.0 when None.
    normalize : bool
        Only affects the automatic subtype ordering (mean share vs. mean count)
        and the y-axis label; bar heights are always per-niche fractions.
    min_frac_to_label : float
        Segments at least this tall get a percentage label.
    title : str or None
        Axes title; a generic title is used when None.
    outfile_prefix : str or None
        Full output path given to ``fig.savefig`` (used as-is despite the name).
    palette : dict or None
        Optional ``subtype -> color``; subtypes without an entry take the next
        color of the axes color cycle.
    subtype_order : sequence of str or None
        Stacking order from bottom to top; present subtypes not listed are
        appended. When None, subtypes are sorted by decreasing mean share
        (``normalize=True``) or mean count.
    niche_order : sequence or None
        Order of the two niches (left bar, right bar); must name both niches
        exactly once. Values are compared as strings.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
    counts : pandas.DataFrame
        Summed weights, niches x subtypes, in plot order.
    props : pandas.DataFrame
        ``counts`` normalized per niche (rows sum to 1, or 0 for an empty niche).

    Raises
    ------
    ValueError
        If the data do not contain exactly two niches, or ``niche_order`` does
        not list both of them exactly once.
    """
    # Compare niche labels as strings everywhere, so `niches` (strings) can index the
    # aggregated table even when `niche_col` holds ints or a categorical.
    niche_keys = data[niche_col].astype(str)
    niches = list(niche_keys.unique())
    if len(niches) != 2:
        raise ValueError(f"Expected exactly 2 niches, found {len(niches)}: {niches}")

    if niche_order is not None:
        order = [str(n) for n in niche_order]
        # Validate that all requested niches exist
        missing = [n for n in order if n not in niches]
        if missing:
            raise ValueError(f"These niches not found in data: {missing}")
        if len(order) != 2 or len(set(order)) != 2:
            raise ValueError(
                f"niche_order must list both niches exactly once, got {list(niche_order)}"
            )
        niches = order
    # else: keep the order of first appearance in `data`

    weights = data[weight_col] if weight_col is not None else pd.Series(1.0, index=data.index)

    # observed=True: for categorical columns, aggregate only combinations that actually
    # occur; unused categories would otherwise become empty segments and legend entries.
    counts = (
        weights.groupby([niche_keys, data[subtype_col]], observed=True).sum()
            .unstack(fill_value=0.0)
            .loc[niches]
    )

    # Determine subtype order (first column = bottom of the stack)
    if subtype_order is not None:
        # Keep only subtypes that actually appear; preserve order from subtype_order
        present = [s for s in subtype_order if s in counts.columns]
        # Add any missing (rare) subtypes at the end
        missing = [s for s in counts.columns if s not in present]
        ordered_cols = present + missing
    else:
        shares = counts.div(counts.sum(axis=1), axis=0).fillna(0.0)
        order_scores = shares.mean(axis=0) if normalize else counts.mean(axis=0)
        ordered_cols = order_scores.sort_values(ascending=False).index.tolist()

    counts = counts[ordered_cols]
    totals = counts.sum(axis=1)
    props = counts.div(totals, axis=0).fillna(0.0)

    # Plot
    fig, ax = plt.subplots(figsize=figsize, dpi=150)
    x0, x1 = 0.0, 1.0
    bar_width = 0.16
    gap = 0.04
    left_bar_x = x0 - bar_width - gap
    right_bar_x = x1 + gap

    def _stack_bounds(row):
        """Cumulative (bottom, top) of each subtype segment on one niche bar."""
        bounds, y = {}, 0.0
        for s in counts.columns:
            h = props.iloc[row][s]
            bounds[s] = (y, y + h)
            y += h
        return bounds

    # Bar segment boundaries (heights are per-niche fractions in both modes)
    y_positions_0 = _stack_bounds(0)
    y_positions_1 = _stack_bounds(1)

    # Bar outlines
    ax.add_patch(Rectangle((left_bar_x, 0), bar_width, 1.0, fill=False))
    ax.add_patch(Rectangle((right_bar_x, 0), bar_width, 1.0, fill=False))

    # Colors
    subtype_colors = {}
    for s in counts.columns:
        if palette is not None and s in palette:
            subtype_colors[s] = palette[s]
        else:
            subtype_colors[s] = ax._get_lines.get_next_color()

    # Draw segments
    for s in counts.columns:
        c = subtype_colors[s]
        yb, yt = y_positions_0[s]
        ax.add_patch(Rectangle((left_bar_x, yb), bar_width, yt - yb, facecolor=c, edgecolor='none'))
        yb, yt = y_positions_1[s]
        ax.add_patch(Rectangle((right_bar_x, yb), bar_width, yt - yb, facecolor=c, edgecolor='none'))

    # Ribbons
    for s in counts.columns:
        c = subtype_colors[s]
        y0b, y0t = y_positions_0[s]
        y1b, y1t = y_positions_1[s]
        path = _ribbon_path(x0, x1, y0b, y0t, y1b, y1t, curvature=0.7, steps=60)
        ax.add_patch(PathPatch(path, facecolor=c, alpha=0.6, edgecolor='none'))

    # Cosmetics
    ax.set_xlim(x0 - bar_width - 2*gap, x1 + bar_width + 2*gap)
    ax.set_ylim(0, 1)
    # NOTE(review): tick labels and percentage labels sit at the ribbon ends (x0/x1),
    # not at the bar centers — confirm this placement is intended.
    ax.set_xticks([x0, x1])
    ax.set_xticklabels([niches[0], niches[1]])
    ax.set_ylabel("Fraction" if normalize else "Normalized height")
    if title is None:
        title = "Alluvial plot of cell subtype composition between niches"
    ax.set_title(title)

    # Labels on sizable segments
    for side, y_positions in ((x0, y_positions_0), (x1, y_positions_1)):
        for s in counts.columns:
            yb, yt = y_positions[s]
            frac = yt - yb
            if frac >= min_frac_to_label:
                ax.text(side, yb + frac/2, f"{frac*100:.0f}%", ha='center', va='center', fontsize=8)

    # Legend
    handles = [Rectangle((0,0),1,1, facecolor=subtype_colors[s], edgecolor='none') for s in counts.columns]
    ax.legend(handles, counts.columns.tolist(), title="Cell subtype",
              bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0.)
    ax.grid(visible=False)

    fig.tight_layout()
    if outfile_prefix is not None:
        fig.savefig(outfile_prefix, bbox_inches='tight', dpi=200)

    return fig, ax, counts, props


In [None]:
import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt

def plot_DEG_dotplot_nature(
    adata,
    niche1_genes,
    niche2_genes,
    gene_group,
    is_contaminant,
    group_labels=("Niche 1", "Niche 2"),
    remove_contaminants=True,
    cmap="RdBu_r",
    save_path=None,
    vmin=-0.5, vmax=0.5,
    groupby="leiden_joint",
):
    """
    Create a Scanpy dotplot of DEGs between two niches, grouped by function.

    Contaminant genes are removed (optionally), each gene is assigned a
    functional category, genes in the "Other" category are dropped, and the
    remaining genes are plotted sorted by niche, function and name.

    Parameters
    ----------
    adata : AnnData
        Expression data; ``adata.obs[groupby]`` defines the dotplot rows.
    niche1_genes, niche2_genes : iterable of str
        Candidate genes for each niche (lists, tuples or NumPy arrays).
    gene_group : callable
        ``gene_group(gene) -> str``; genes mapped to ``"Other"`` are dropped.
    is_contaminant : callable
        ``is_contaminant(gene) -> bool``; only used when ``remove_contaminants``.
    group_labels : tuple of str
        Labels written to the ``niche`` column of the returned table.
    remove_contaminants : bool
        Drop genes flagged by ``is_contaminant``.
    cmap, vmin, vmax
        Color mapping forwarded to ``sc.pl.dotplot``.
    save_path : str or None
        If given, the rendered dotplot is saved there at 300 dpi.
    groupby : str
        obs column used for the dotplot rows (default ``"leiden_joint"``).

    Returns
    -------
    pandas.DataFrame
        Columns ``gene``, ``niche``, ``function``; one row per plotted gene in
        x-axis order.

    Raises
    ------
    ValueError
        If no genes remain after filtering.
    """
    # Materialize as plain lists: callers pass NumPy arrays (``.names.ravel()``), and
    # ``array + array`` would add element-wise instead of concatenating.
    niche1_genes = list(niche1_genes)
    niche2_genes = list(niche2_genes)

    # -------------------------------
    # 1. Remove contaminant genes
    # -------------------------------
    if remove_contaminants:
        niche1_genes = [g for g in niche1_genes if not is_contaminant(g)]
        niche2_genes = [g for g in niche2_genes if not is_contaminant(g)]

    # -------------------------------
    # 2. Annotate functional categories
    # -------------------------------
    df = pd.DataFrame({
        "gene": niche1_genes + niche2_genes,
        "niche": [group_labels[0]] * len(niche1_genes) + [group_labels[1]] * len(niche2_genes)
    })
    df["function"] = df["gene"].apply(gene_group)
    # Filter out genes not assigned to any known functional group
    df = df[df["function"] != "Other"].copy()

    # -------------------------------
    # 3. Sort genes by niche, function & name
    # -------------------------------
    df = df.sort_values(["niche", "function", "gene"])
    grouped_genes = df["gene"].tolist()
    if not grouped_genes:
        raise ValueError("No genes left to plot after contaminant/function filtering.")

    # -------------------------------
    # 4. Create dotplot using Scanpy
    # -------------------------------
    sc.set_figure_params(dpi=200, dpi_save=300, facecolor="white")

    # return_fig=True returns a DotPlot object whose figure is only built on demand.
    dp = sc.pl.dotplot(
        adata,
        var_names=grouped_genes,
        groupby=groupby,  # must exist in adata.obs
        vmin=vmin, vmax=vmax,
        cmap=cmap,
        dot_max=0.6,
        dot_min=0.05,
        show=False,
        return_fig=True,
        figsize=(9, 2),
    )

    # -------------------------------
    # 5. Render (and optionally save)
    # -------------------------------
    if save_path:
        # DotPlot.savefig builds the figure itself before writing it.
        dp.savefig(save_path, bbox_inches="tight", dpi=300)
        print(f"✅ Saved: {save_path}")
    else:
        # Without this, plt.show() would have no figure to display.
        dp.make_figure()

    plt.show()
    return df


## How does the epithelium change?

In [None]:
# Epithelial cells from niches '1' and '2'.
# NOTE(review): niches are selected on `leiden` here, while later plots color by
# `leiden_joint` — confirm the two labelings agree for these niches.
subadata = adata[(adata.obs.leiden.isin(['1','2'])) & (adata.obs.refined_celltypes.isin(['Epithelial']))].copy()

In [None]:
# Condition = last three characters of sample_id (e.g. 'CRC' / 'NAT' for sample names like 'P2CRC' / 'P3NAT').
subadata.obs['condition'] = subadata.obs.sample_id.str[-3:]

# PCA, then a batch-balanced kNN graph (bbknn) using sample_id as the batch key.
sc.tl.pca(subadata)

sc.external.pp.bbknn(subadata, batch_key='sample_id')

In [None]:
# UMAP embedding from the bbknn neighbor graph computed in the previous cell.
sc.tl.umap(subadata)

In [None]:
# Map fine epithelial subtypes (refined_cellsubtypes labels) to three broad groups.
subtype_dict = {
    # Secretory lineage
    "Epithelial—Goblet/secretory": "Epithelial—Secretory",
    "Epithelial—crypt/secretory": "Epithelial—Secretory",

    # Progenitor / regenerative / undifferentiated
    "Epithelial—Stem/TA": "Epithelial—Progenitor/Regenerative",
    "Epithelial—TA (pre-absorptive)": "Epithelial—Progenitor/Regenerative",
    "Epithelial—injury/regenerative (LCN2+)": "Epithelial—Progenitor/Regenerative",
    "Epithelial—unspecified": "Epithelial—Progenitor/Regenerative",
    "Epithelial—low-grade/dysplastic": "Epithelial—Progenitor/Regenerative",

    # Mature absorptive lineage (colonocytes)
    "Epithelial—Mature colonocyte (CA+)": "Epithelial—Mature colonocyte",
    "Epithelial—Mature colonocyte I": "Epithelial—Mature colonocyte",
    "Epithelial—Mature colonocyte II": "Epithelial—Mature colonocyte",
    "Epithelial—Mature colonocyte (absorptive)": "Epithelial—Mature colonocyte",
}


In [None]:
# Broad subtype per cell; labels not listed in subtype_dict are kept unchanged by .replace.
subadata.obs['Cell Subtype'] = subadata.obs.refined_cellsubtypes.replace(subtype_dict)

In [None]:
# Palette spec per obs column (all drawn from `tab_filtered`, defined elsewhere).
# The same spec dict is reused for the lymphoid subset further below.
palette_specs = {
            'Cell Subtype': tab_filtered,
            'condition': tab_filtered,
            'sample_id': tab_filtered,
        }

# Palettes labeled by the epithelial subset's categories.
palette_dict_2 = build_palettes_from_adata(subadata, palette_specs)

In [None]:
# 2x2 grid of epithelial UMAPs: leiden_joint colors from palette_dict_1, the rest from palette_dict_2.
plot_umap_nature_style(
    subadata,
    color_vars=["leiden_joint", "condition", "Cell Subtype", "sample_id"],
    palettes={'leiden_joint': palette_dict_1['leiden_joint'],
              "condition": palette_dict_2["condition"],
              "Cell Subtype": palette_dict_2["Cell Subtype"],
              "sample_id": palette_dict_2[ "sample_id"]},
    ncols=2,
    point_size=2,
    cmap="viridis",
    frameon=False,
    figsize=(5, 3),
    title_fontsize=9,
    label_fontsize=9,
    save_path='../../../SpatialFusion/results/figures_Fig5/epithelial_comparison_umap.svg',
)

In [None]:
# Number of epithelial cells per broad subtype.
subadata.obs['Cell Subtype'].value_counts()

In [None]:
# Mean expression per gene across epithelial cells; np.asarray(...).ravel() flattens
# the result whether X is dense or sparse.
avg_expr = pd.Series(np.asarray(subadata.X.mean(axis=0)).ravel(), index=subadata.var_names)

# Keep only genes with mean expression > 0.5 for differential testing.
hex_subadata = subadata[:,avg_expr[avg_expr>0.5].index].copy()

In [None]:
# Wilcoxon DE test between `leiden` groups on the expression-filtered genes.
sc.tl.rank_genes_groups(hex_subadata, groupby='leiden', method='wilcoxon')

In [None]:
# One DE result table per `leiden` group, keyed by group label (later read as dgex['1'] / dgex['2']).
dgex = {}
for gr in hex_subadata.obs.leiden.unique():
    dgex[gr] = sc.get.rank_genes_groups_df(hex_subadata, group=gr)

## Plot DGEX

In [None]:
# Scaled copy (sc.pp.scale: per-gene zero mean / unit variance) used for the dotplot colors.
tmp = subadata.copy()

sc.pp.scale(tmp)

In [None]:
def gene_group(g):
    """
    Assign an epithelial gene symbol to a coarse functional category.

    Matching is a case-insensitive substring search over ordered rules; the
    first rule with a keyword occurring anywhere in the symbol wins, so rule
    order matters (e.g. "PIGR" contains "IG" and maps to "Immune / inflammatory").
    Returns "Other" when no rule matches.
    """
    symbol = g.upper()
    rules = (
        ("Goblet / secretory", ("MUC", "TFF", "FCGBP", "ZG16", "SPINK")),
        ("Antimicrobial / transport", ("PLA2", "DUOX", "REG", "LCN", "PI3", "SLC", "ASS")),
        ("Immune / inflammatory", ("CEACAM", "IGH", "IG", "SERPIN")),
        ("Mitochondrial / metabolic", ("MT-", "ECH", "ATP", "UQCR", "EEF", "HMGCS2", "ECHS")),
        ("Epithelial / regulatory", ("TSPAN", "KLF", "ELF", "EGR", "CD24", "BSG")),
    )
    for label, keywords in rules:
        if any(kw in symbol for kw in keywords):
            return label
    return "Other"

def is_contaminant(g):
    """
    Return True for genes excluded from the epithelial DEG plot.

    Case-insensitive prefix test: mitochondrial (MT-), ribosomal (RPS/RPL),
    IG-prefixed and histone (HIST) genes.
    """
    excluded_prefixes = ("MT-", "RPS", "RPL", "IG", "HIST")
    return g.upper().startswith(excluded_prefixes)

In [None]:
# Your differential gene lists
niche2_genes = dgex['2'][dgex['2']['logfoldchanges']>0].sort_values('logfoldchanges', ascending=False).head(25).names.ravel()
niche1_genes = dgex['1'][dgex['1']['logfoldchanges']>0].sort_values('logfoldchanges', ascending=False).head(25).names.ravel()


deg_df = plot_DEG_dotplot_nature(
    adata=tmp,
    niche1_genes=niche1_genes,
    niche2_genes=niche2_genes,
    gene_group=gene_group,
    is_contaminant=is_contaminant,
    group_labels=("Niche 1", "Niche 2"),
    save_path="../../../SpatialFusion/results/figures_Fig5/dotplot_epithelial_deg.svg",
    vmin=-0.25, vmax=0.25,
)


In [None]:
# Plotted genes with niche and functional category.
deg_df

In [None]:
subdf =subadata.obs.copy()
# --- Alluvial plot of epithelial subtype composition, niche 1 vs niche 2 ---
fig, ax, counts_df, props_df = alluvial_two_groups(
    subdf,
    niche_col="leiden_joint",
    subtype_col="Cell Subtype",
    palette=palette_dict_2["Cell Subtype"],
    normalize=True,
    title="Epithelial cell composition",
    niche_order=['1','2'],
    figsize=(6.5,2),
    outfile_prefix="../../../SpatialFusion/results/figures_Fig5/epithelial_alluvial.svg",
)

## How do lymphoid cells change?

In [None]:
# Lymphoid cells from niches '1' and '2' (selected on `leiden`, like the epithelial subset).
lymphadata = adata[(adata.obs.leiden.isin(['1','2'])) & (adata.obs.refined_celltypes.isin(['Lymphoid']))].copy()

In [None]:
# Condition = last three characters of sample_id (same rule as for the epithelial subset).
lymphadata.obs['condition'] = lymphadata.obs.sample_id.str[-3:]

# PCA -> bbknn neighbor graph (sample_id as batch) -> UMAP.
sc.tl.pca(lymphadata)

sc.external.pp.bbknn(lymphadata, batch_key='sample_id')

sc.tl.umap(lymphadata)

In [None]:
# Map fine lymphoid subtypes (refined_cellsubtypes labels) to Plasma / B cell / T cell.
celltype_groups = {
    # B lineage
    "B lineage—Plasma cell": "Plasma",
    "B lineage—Plasma cell (activated)": "Plasma",
    "B lineage—activated B / GC-like": "B cell",
    "B lineage—B cell": "B cell",

    # T lineage
    "T cell—naive/Tfh": "T cell",
    "T cell—unspecified": "T cell",
}


In [None]:
# Broad subtype per cell; labels not listed in celltype_groups are kept unchanged by .replace.
lymphadata.obs['Cell Subtype'] = lymphadata.obs.refined_cellsubtypes.replace(celltype_groups)

In [None]:
# Palettes labeled by the lymphoid subset's categories (same palette_specs as the epithelial subset).
palette_dict_3 = build_palettes_from_adata(lymphadata, palette_specs)

In [None]:
# 2x2 grid of lymphoid UMAPs. Cell Subtype uses palette_dict_3 (lymphoid labels).
# NOTE(review): condition / sample_id reuse palette_dict_2 (built from the epithelial
# subset), presumably to keep colors consistent with the epithelial UMAP — confirm every
# lymphoid label is covered.
plot_umap_nature_style(
    lymphadata,
    color_vars=["leiden_joint", "condition", "Cell Subtype", "sample_id"],
    palettes={'leiden_joint': palette_dict_1['leiden_joint'],
              "condition": palette_dict_2["condition"],
              "Cell Subtype": palette_dict_3["Cell Subtype"],
              "sample_id": palette_dict_2[ "sample_id"]},
    ncols=2,
    point_size=2,
    cmap="viridis",
    frameon=False,
    figsize=(5, 3),
    title_fontsize=9,
    label_fontsize=9,
    save_path='../../../SpatialFusion/results/figures_Fig5/lymphoid_comparison_umap.svg',
)

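We subset to well-expressed genes (mean expression > 0.5 across lymphoid cells) to get a better representation in the differential expression analysis.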

In [None]:
# Mean expression per gene across lymphoid cells; np.asarray(...).ravel() flattens the
# result whether X is dense or sparse. Index by lymphadata.var_names so the Series is
# aligned with the matrix it was computed from (matches the epithelial analysis).
avg_expr = pd.Series(np.asarray(lymphadata.X.mean(axis=0)).ravel(), index=lymphadata.var_names)

# Keep only genes with mean expression > 0.5 for differential testing.
hex_lymphadata = lymphadata[:,avg_expr[avg_expr>0.5].index].copy()

In [None]:
# Wilcoxon DE test between `leiden_joint` groups (the epithelial analysis above grouped by `leiden`).
sc.tl.rank_genes_groups(hex_lymphadata, groupby='leiden_joint', method='wilcoxon')

In [None]:
# One DE table per group, keyed by group label. Iterate the groups of the column that
# rank_genes_groups actually used (stored in its params; 'leiden_joint' above), not
# `leiden`, so every requested group is guaranteed to exist in the results.
rgg_groupby = hex_lymphadata.uns['rank_genes_groups']['params']['groupby']
dgex_lymphoid = {}
for gr in hex_lymphadata.obs[rgg_groupby].unique():
    dgex_lymphoid[gr] = sc.get.rank_genes_groups_df(hex_lymphadata, group=gr)

In [None]:
# Niche 2 genes with the largest log fold changes.
dgex_lymphoid['2'].sort_values('logfoldchanges', ascending=False).head(15)

In [None]:
# Gene names of the first 50 rows of the niche 2 table (rank_genes_groups_df order).
dgex_lymphoid['2'].head(50).names.ravel()

In [None]:
# Gene names of the first 50 rows of the niche 1 table (rank_genes_groups_df order).
dgex_lymphoid['1'].head(50).names.ravel()

In [None]:
def gene_group(g):
    """
    Assign a lymphoid gene symbol to a coarse functional category.

    Case-insensitive substring matching over ordered rules; the first rule with
    a keyword contained in the symbol wins. Returns "Other" if nothing matches.
    """
    symbol = g.upper()

    secretory = (
        "XBP1", "MZB1", "DERL3", "TXNDC5", "SEL1L3", "TXNDC11",
        "IGHM", "IGHA1", "IGHG1", "IGKC", "JCHAIN", "PIGR",
        "RRBP1", "SEC11C", "SSR2", "SSR4",
    )
    stress = (
        "FOS", "JUN", "DUSP1", "ZFP36", "ZFP36L2", "TXNIP", "DDIT4",
        "TSC22D3", "BTG2", "PDCD4", "EGR1", "JUNB", "JUND",
    )
    mito = ("MT-ND4", "MT-ND5", "MT-CO2", "MT-ATP6", "FTH1")

    ordered_rules = [
        ("Secretory / ER expansion", secretory),   # antibody production / ER load
        ("Stress / activation", stress),           # immediate-early response
        ("Mitochondrial / oxidative", mito),
    ]
    return next(
        (label for label, keys in ordered_rules if any(k in symbol for k in keys)),
        "Other",
    )


def is_contaminant(g):
    """
    Return True for genes excluded from the lymphoid DEG plot.

    Case-insensitive: ribosomal (RPS/RPL) and histone (HIST) prefixes, plus any
    symbol containing one of the listed endothelial/stromal marker substrings.
    """
    symbol = g.upper()
    if symbol.startswith(("RPS", "RPL", "HIST")):
        return True
    # endothelial / stromal genes
    stromal_markers = (
        "VWF", "PECAM", "COL", "FN1", "LYVE1", "TEK",
        "ASPN", "MFAP", "TFPI2", "EDNRA", "NOTCH3", "RAMP3",
    )
    return any(marker in symbol for marker in stromal_markers)

In [None]:
# Scaled copy (sc.pp.scale: per-gene zero mean / unit variance) of the expression-filtered
# lymphoid data, used for the dotplot colors.
tmp = hex_lymphadata.copy()

sc.pp.scale(tmp)

In [None]:
# Your differential gene lists
niche2_genes = dgex_lymphoid['2'][dgex_lymphoid['2']['logfoldchanges']>0].sort_values('logfoldchanges', ascending=False).head(25).names.ravel()
niche1_genes = dgex_lymphoid['1'][dgex_lymphoid['1']['logfoldchanges']>0].sort_values('logfoldchanges', ascending=False).head(25).names.ravel()
deg_df = plot_DEG_dotplot_nature(
    adata=tmp,
    niche1_genes=niche1_genes,
    niche2_genes=niche2_genes,
    gene_group=gene_group,
    is_contaminant=is_contaminant,
    group_labels=("Niche 1", "Niche 2"),
    save_path="../../../SpatialFusion/results/figures_Fig5/dotplot_lymphoid_deg.svg",
    vmin=-0.25, vmax=0.25,
)


In [None]:
# Plotted lymphoid genes with niche and functional category.
deg_df

In [None]:
# --- Alluvial plot of lymphoid subtype composition, niche 1 vs niche 2 ---
subdf = lymphadata.obs.copy()
# Use palette_dict_3: it was built from `lymphadata` for the lymphoid 'Cell Subtype' labels
# and matches the lymphoid UMAP. palette_dict_2 was built from the epithelial subset.
fig, ax, counts_df, props_df = alluvial_two_groups(
    subdf,
    niche_col="leiden_joint",
    subtype_col="Cell Subtype",
    palette=palette_dict_3["Cell Subtype"],
    normalize=True,
    title="Lymphoid cell composition",
    niche_order=['1','2'],
    figsize=(4,2),
    outfile_prefix="../../../SpatialFusion/results/figures_Fig5/lymphoid_alluvial.svg",
)