# Functions

In [None]:
import matplotlib
# Emit text in SVG exports as text elements instead of converting glyphs to
# paths, so labels stay editable in vector-graphics editors.
matplotlib.rcParams['svg.fonttype'] = 'none'

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as mpatches
import pathlib as pl
import scanpy as sc
import pandas as pd
import numpy as np
from tqdm import tqdm

def group_small_clusters(
    df: pd.DataFrame,
    cluster_col: str,
    min_count: int = 1000,
    new_label: str = "small_clusters",
    output_col: str = None
) -> pd.Series:
    """
    Collapse rare cluster labels into a single catch-all label.

    Parameters:
        df (pd.DataFrame): Input DataFrame containing cluster labels.
        cluster_col (str): Name of the column containing cluster labels (e.g., 'leiden').
        min_count (int): Clusters with fewer than this many rows are grouped.
        new_label (str): Label assigned to every row of a grouped (small) cluster.
        output_col (str or None): Name given to the returned Series.
                                  If None, defaults to '{cluster_col}_grouped'.

    Returns:
        pd.Series: String labels aligned to ``df.index`` and named ``output_col``;
        rows of small clusters carry ``new_label``. Missing values are not
        counted as a cluster and are stringified (NaN becomes 'nan').
        ``df`` itself is not modified.

    Raises:
        ValueError: If ``cluster_col`` is not a column of ``df``.
    """
    if cluster_col not in df.columns:
        raise ValueError(f"Column '{cluster_col}' not found in DataFrame.")

    output_col = output_col or f"{cluster_col}_grouped"
    labels = df[cluster_col]

    cluster_counts = labels.value_counts()
    small_clusters = cluster_counts.index[cluster_counts < min_count]

    # Only the label column is needed, so there is no need to copy all of df.
    # A positional mask avoids index alignment issues with duplicate labels.
    is_small = labels.isin(small_clusters).to_numpy()
    grouped = labels.astype(str).where(~is_small, new_label)
    grouped.name = output_col
    return grouped


In [None]:
def build_palettes_from_adata(adata, palette_specs):
    """
    Build labeled color palettes for categorical columns in adata.obs.

    Labels are the sorted, string-converted, non-missing values of each
    column; colors are assigned to them in that order.

    Parameters
    ----------
    adata : AnnData
        Any object with an ``.obs`` DataFrame holding the columns to color.
    palette_specs : dict
        Mapping {column_name: palette} where palette can be:
          - a string palette name (e.g. "tab10"), expanded with seaborn
          - a list/tuple of colors (custom), used in order

    Returns
    -------
    dict
        {column_name: {label: color}} mapping. Columns missing from
        ``adata.obs`` are skipped with a warning.

    Raises
    ------
    ValueError
        If a palette is neither a string nor a list/tuple.
    """
    custom_palettes = {}

    for col, palette in palette_specs.items():
        if col not in adata.obs.columns:
            print(f"⚠️ Warning: '{col}' not found in adata.obs — skipping.")
            continue

        # Drop real missing values *before* stringifying: after astype(str),
        # NaN/None would already be the labels 'nan'/'None' and get a color.
        unique_vals = sorted(adata.obs[col].dropna().astype(str).unique())
        n_unique = len(unique_vals)
        print(f"{col}: {n_unique} labels")

        # If user passed a name → generate via seaborn
        if isinstance(palette, str):
            pal_colors = sns.color_palette(palette, n_colors=n_unique)
        # If user passed a list → use directly
        elif isinstance(palette, (list, tuple)):
            pal_colors = list(palette[:n_unique])
            if len(pal_colors) < n_unique:
                # zip() below would silently leave the extra labels uncolored.
                print(f"⚠️ Warning: palette for '{col}' has {len(pal_colors)} colors "
                      f"but {n_unique} labels; {n_unique - len(pal_colors)} labels get no color.")
        else:
            raise ValueError(f"Unsupported palette type for '{col}': {type(palette)}")

        custom_palettes[col] = dict(zip(unique_vals, pal_colors))

    print(f"✅ Built palettes for {len(custom_palettes)} columns.")
    return custom_palettes


def plot_celltype_spatial_single_split_legend(
    df,
    color_by="celltype",
    sample_id=None,
    title=None,
    palette_dict=None,         # ✅ added
    palette_name="tab20",
    s=1.5,
    save_svg=True,
    output_prefix="spatial_plot",
    legend_title=None,
):
    """
    Nature Genetics–style spatial scatterplot for one sample,
    saving main plot as PNG (raster) and legend separately as SVG (vector).

    Parameters
    ----------
    df : pd.DataFrame
        Per-cell table with ``X_coord``, ``Y_coord`` and ``color_by`` columns
        (and ``sample_id`` when ``sample_id`` is given).
    color_by : str
        Column used for point colors.
    sample_id : str, optional
        If given, only rows with this ``sample_id`` are plotted. Also used in
        the output file names ('sample' when None).
    title : str, optional
        Title drawn above the main plot.
    palette_dict : dict, optional
        {column: {label: color}}; used when it has an entry for ``color_by``.
        That entry must cover every plotted label.
    palette_name : str
        Seaborn palette used when ``palette_dict`` has no entry for ``color_by``.
    s : float
        Marker size.
    save_svg : bool
        Whether to write the legend SVG. The main PNG is always written.
    output_prefix : str
        Path prefix of the output files.
    legend_title : str, optional
        Legend title; defaults to ``color_by``.

    Raises
    ------
    ValueError
        If ``sample_id`` is given but no row of ``df`` has it.
    """
    sns.set_style("white")
    sns.set_context("talk")

    # --- Subset one sample ---
    if sample_id is not None:
        df = df[df["sample_id"] == sample_id].copy()
        if df.empty:
            raise ValueError(f"Sample ID '{sample_id}' not found in DataFrame.")

    # --- Colors ---
    unique_labels = sorted(df[color_by].dropna().unique())
    if palette_dict is not None and color_by in palette_dict:
        print('Using provided color palette.')
        color_dict = palette_dict[color_by]
    else:
        print('Generating color palette.')
        palette = sns.color_palette(palette_name, n_colors=len(unique_labels))
        color_dict = dict(zip(unique_labels, palette))

    file_tag = sample_id or 'sample'

    # --- Main plot (closed even if drawing/saving fails, so loops don't leak figures) ---
    fig, ax = plt.subplots(figsize=(6, 5), dpi=300)
    try:
        sns.scatterplot(
            data=df,
            x="X_coord", y="Y_coord",
            hue=color_by, palette=color_dict,
            s=s, alpha=0.9, linewidth=0,
            rasterized=True, ax=ax, legend=False
        )
        # Image-style orientation: y grows downwards.
        ax.invert_yaxis()
        ax.set_aspect("equal", adjustable="box")
        for spine in ["top", "right", "left", "bottom"]:
            ax.spines[spine].set_visible(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel("")
        ax.set_ylabel("")
        if title:
            ax.set_title(title)
        fig.tight_layout()

        fname_main = f"{output_prefix}_{file_tag}_main.png"
        fig.savefig(fname_main, dpi=300, bbox_inches="tight", transparent=True, format="png")
        print(f"Saved main figure: {fname_main}")
    finally:
        plt.close(fig)

    # --- Legend (separate figure so it can be saved as vector graphics) ---
    # The figure height must stay positive even when there is nothing to list.
    legend_height = max(0.5 * len(unique_labels), 0.5)
    fig_leg, ax_leg = plt.subplots(figsize=(3, legend_height), dpi=300)
    try:
        handles = [
            plt.Line2D([0], [0], marker='o', color='none', label=label,
                       markerfacecolor=color_dict[label], markersize=8)
            for label in unique_labels
        ]
        ax_leg.legend(handles=handles, loc="center left", frameon=False,
                      title=legend_title or color_by, title_fontsize=14, fontsize=14)
        ax_leg.axis("off")
        fig_leg.tight_layout()

        if save_svg:
            fname_leg = f"{output_prefix}_{file_tag}_legend.svg"
            fig_leg.savefig(fname_leg, dpi=300, bbox_inches="tight", transparent=True, format="svg")
            print(f"Saved legend: {fname_leg}")
    finally:
        plt.close(fig_leg)


# Data

In [None]:
import pathlib as pl
import pandas as pd
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

## Clean clinical info

In [None]:
sample_list = ['LIB-064885st1', 'LIB-064886st1', 'LIB-064887st1', 'LIB-064889st1',
       'LIB-064890st1', 'LIB-065290st1', 'LIB-065291st1', 'LIB-065292st1',
       'LIB-065294st1', 'LIB-065295st1']

In [None]:
basic_clinical = pd.read_csv('../../../Broad_SpatialFoundation/VisiumHD-LUAD/clinical-info/basic-clinical-Novartis.csv').set_index('PID')

basic_clinical = basic_clinical[['Library','Block_ID','Histology','Age','Gender']]

basic_clinical = basic_clinical.loc[basic_clinical.Library.isin(sample_list)]

tumor_class = pd.read_csv('../../../Broad_SpatialFoundation/VisiumHD-LUAD/clinical-info/classification-tumor-novartis.csv').set_index('case no')

tumor_class = tumor_class.loc[basic_clinical.index,['pTNM T', 'pTNM N', 'pTNM  V', 'stage', 'grading']]

tumor_class['pTNM T red'] = tumor_class['pTNM T'].replace({'pT3': 'T3/4', 'pT2b': 'T1/2', 'pT1b': 'T1/2', 'pT4': 'T3/4',})

tumor_class['pTNM N red'] = tumor_class['pTNM N'].replace({'pN0': 'pN0', 'pN1': 'pN1/2/3', 'pN2': 'pN1/2/3', 'pN3': 'pN1/2/3',})

tumor_class['stage red'] = tumor_class['stage'].replace({'IIB': 'I/II', 'IVB': 'III/IV', 'III A': 'III/IV', 'III C': 'III/IV', 'IA': 'I/II',})

clinical_info =  pd.concat([basic_clinical, tumor_class], axis=1)

medication = pd.read_csv('../../../Broad_SpatialFoundation/VisiumHD-LUAD/clinical-info/current-medication-novartis.csv')

# 1. Drop duplicates if any (same case–agent combination repeated)
df_unique = medication.drop_duplicates(subset=['case no', 'agent'])

# 2. Create a binary indicator for each row
df_unique['received'] = 1

# 3. Pivot wider: one column per agent
df_wide = (
    df_unique
    .pivot_table(index='case no', 
                 columns='agent', 
                 values='received', 
                 fill_value=0, 
                 aggfunc='max')  # in case multiple entries
)

# 4. (Optional) flatten the column index if needed
df_wide = df_wide.reset_index()

df_wide = df_wide.set_index('case no').loc[clinical_info.index]


cts = df_wide.sum(axis=0)
tokeep = cts[(cts>=2) & (cts<=8)].index
df_wide = df_wide[tokeep]

clinical_info = pd.concat([clinical_info, df_wide], axis=1)

previous_disease = pd.read_csv('../../../Broad_SpatialFoundation/VisiumHD-LUAD/clinical-info/previous-diseases-novartis.csv')

# 1. Drop duplicates if any (same case–agent combination repeated)
df_unique = previous_disease.drop_duplicates(subset=['case no', 'description'])

# 2. Create a binary indicator for each row
df_unique['received'] = 1

# 3. Pivot wider: one column per agent
df_wide = (
    df_unique
    .pivot_table(index='case no', 
                 columns='description', 
                 values='received', 
                 fill_value=0, 
                 aggfunc='max')  # in case multiple entries
)

# 4. (Optional) flatten the column index if needed
df_wide = df_wide.reset_index()

df_wide = df_wide.set_index('case no').loc[clinical_info.index]


cts = df_wide.sum(axis=0)
tokeep = cts[(cts>=2) & (cts<=8)].index
df_wide = df_wide[tokeep]

clinical_info = pd.concat([clinical_info, df_wide], axis=1)

tsize = pd.read_csv('../../../Broad_SpatialFoundation/VisiumHD-LUAD/clinical-info/tumor-size-ischemia-novartis.csv').set_index('case no')

clinical_info = pd.concat([clinical_info, tsize.loc[clinical_info.index,['tumor size [cm]']]], axis=1)

clinical_info.to_csv('../../../Broad_SpatialFoundation/VisiumHD-LUAD/clinical-info/full_clinical.csv')

In [None]:
clinical_info = pd.read_csv('../../../Broad_SpatialFoundation/VisiumHD-LUAD/clinical-info/full_clinical.csv', index_col=0)

## Get clusters

In [None]:
base_dir = pl.Path('../../../Broad_SpatialFoundation/VisiumHD-LUAD-processed/')
sample_list = np.setdiff1d([f.stem for f in base_dir.iterdir()],['full_cohort','LIB-064888st1'])
sample_list

In [None]:
all_embeddings = {}
for sample in sample_list:
    all_embeddings[sample] = pd.read_parquet(base_dir / f'{sample}/embeddings/NicheFinder.parquet')
    all_embeddings[sample].columns = all_embeddings[sample].columns.astype(str)
   

In [None]:
# Load every sample, attach its SpatialFusion embedding, make cell IDs unique
# across samples, drop 'Noise' cells and concatenate into one AnnData.
adatas = []
for sample in tqdm(sample_list):
    adata = sc.read_h5ad(base_dir / sample / 'adata.h5ad')
    embeddings_df = all_embeddings[sample].copy()
    embeddings_df = embeddings_df.set_index('cell_id')
    
    # Keep only cells present in both the expression object and the embedding table.
    common_idx = adata.obs_names.intersection(embeddings_df.index)
    adata = adata[common_idx].copy()
    embeddings_df = embeddings_df.loc[common_idx]
    # Embedding dimensions '0'..'9' (string labels, cast in the loading cell).
    adata.obsm['SpatialFusion'] = embeddings_df.loc[:,['0','1','2','3','4','5','6','7','8','9']]
    
    # Unique cell key across samples: '<cell_id>::<sample>'.
    adata.obs_names = adata.obs_names + '::' + sample 
    adata.obs['sample_id'] = sample
    adata = adata[adata.obs.celltypes != 'Noise'].copy()
    adatas.append(adata)
# NOTE(review): AnnData.concatenate is deprecated in recent anndata releases
# (confirm for the installed version). By default it appends a '-<batch>'
# suffix to obs_names, which the later '-'.join(...[:-1]) cell strips; moving to
# anndata.concat would change obs_names, so update that cell as well.
adata = adatas[0].concatenate(*adatas[1:])


In [None]:
# Stack all per-sample embedding tables, indexed by '<cell_id>::<sample_id>'.
# This matches adata.obs_names as long as the table's sample_id column equals
# the directory name used in the loading loop (assumed — verify).
embeddings_df = []
for sample in all_embeddings:
    df = all_embeddings[sample].copy()
    df.index = df.cell_id + '::' + df.sample_id
    embeddings_df.append(df)
embeddings_df = pd.concat(embeddings_df)

In [None]:
# Remove the last '-'-separated piece of every cell name (presumably the batch
# suffix added by concatenate() in the loading cell). Re-joining the remaining
# pieces keeps any '-' that was part of the original ID; a name without any
# '-' would become ''.
adata.obs_names = ['-'.join(f) for f in adata.obs_names.str.split('-').str[:-1]]

In [None]:
refined_celltypes = pd.read_csv('full_refined_celltypes.csv', index_col=0)

refined_celltypes = refined_celltypes.replace({'Noise': np.nan})

adata.obs = pd.concat([adata.obs,refined_celltypes[['refined_cellsubtypes','refined_celltypes']]],axis=1)

In [None]:
common_idx = adata.obs_names.intersection(embeddings_df.index)

adata = adata[common_idx].copy()

embeddings_df = embeddings_df.loc[common_idx]

In [None]:
sc.pp.neighbors(adata, use_rep='SpatialFusion')

In [None]:
sc.tl.leiden(adata, resolution=0.2, flavor="igraph", n_iterations=2)

In [None]:
# Get cluster sizes
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Make mapping: old → new (ranked by size)
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply mapping
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# (optional) sort categories by their new numeric label
adata.obs['leiden'].cat.reorder_categories(sorted(adata.obs['leiden'].cat.categories, key=int))

print("Cluster relabeling done ✅")# Get cluster sizes
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
embeddings_df['cellsubtypes'] = adata.obs.loc[embeddings_df.index, 'refined_cellsubtypes'].ravel()

In [None]:
embeddings_df['celltypes'] = adata.obs.loc[embeddings_df.index, 'refined_celltypes'].ravel()

In [None]:
# Pool niches with fewer than 1000 cells into "Other". The returned Series is
# indexed like adata.obs, so the assignment aligns onto embeddings_df by cell ID.
embeddings_df['leiden_joint'] = group_small_clusters(
    adata.obs[['leiden']],
    cluster_col='leiden',
    min_count= 1000,
    new_label= "Other",
    output_col = None
)

In [None]:
# Copy back onto adata.obs (index-aligned assignment).
adata.obs['leiden_joint'] = embeddings_df['leiden_joint']

In [None]:
# String categorical; any missing value becomes the category 'nan'.
adata.obs['leiden_joint'] = adata.obs['leiden_joint'].astype(str).astype('category')

In [None]:
adata.obs.to_parquet('nsclc_adata_obs.parquet')

In [None]:
adata_obs = pd.read_parquet('nsclc_adata_obs.parquet')

adata.obs = adata_obs.loc[adata_obs.index]

In [None]:
# Current default seaborn palette without the entries at indices 4 and 6.
tab_filtered = sns.color_palette()
tab_filtered = [c for i,c in enumerate(tab_filtered) if i not in [4,6]]

# tab20 followed by tab20c, without the entries at indices 8, 9, 12 and 13.
tab20_filtered = sns.color_palette('tab20') + sns.color_palette('tab20c') 
tab20_filtered = [c for i,c in enumerate(tab20_filtered) if i not in [8,9,12,13]]

In [None]:
# obs column -> color list; turned into {label: color} per column.
palette_specs = {
            'leiden_joint': tab20_filtered,
            'refined_cellsubtypes': tab20_filtered,
            'refined_celltypes': tab_filtered,
        }

palette_dict_1 = build_palettes_from_adata(adata, palette_specs)

In [None]:
import json

# Persist the palettes for reuse; json writes the RGB tuples as lists.
with open("palettes_NSCLC_Novartis.json", "w") as f:
    json.dump(palette_dict_1, f, indent=2)

In [None]:
# Spatial coordinates as obs columns, as expected by the spatial plotting helper.
adata.obs['X_coord'] = adata.obsm['spatial'][:,0]
adata.obs['Y_coord'] = adata.obsm['spatial'][:,1]

In [None]:
for sample in tqdm(adata.obs.sample_id.unique()):
    print(sample)
    tmp = adata[adata.obs.sample_id==sample].copy()
    plot_celltype_spatial_single_split_legend(
        tmp.obs,
        color_by="refined_celltypes",
        sample_id=sample,
        title=None,
        palette_dict=palette_dict_1,         # ✅ added
        s=1,
        save_svg=True,
        output_prefix=f"../../../SpatialFusion/results/figures_Fig6/celltypes",
        legend_title=None,
    )

In [None]:
for sample in tqdm(adata.obs.sample_id.unique()):
    print(sample)
    tmp = adata[adata.obs.sample_id==sample].copy()
    plot_celltype_spatial_single_split_legend(
        tmp.obs,
        color_by="leiden_joint",
        sample_id=sample,
        title=None,
        palette_dict=palette_dict_1,         # ✅ added
        s=1,
        save_svg=True,
        output_prefix=f"../../../SpatialFusion/results/figures_Fig6/niches",
        legend_title=None,
    )

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# --- Data prep ---
# Numbered niches in numeric order; the pooled 'Other' label and 'nan' are excluded.
cl_order = np.sort(np.setdiff1d(adata.obs.leiden_joint.unique(),['Other','nan']).astype(int)).astype(str)
sample_order = ['LIB-064885st1','LIB-065291st1','LIB-065292st1','LIB-065294st1','LIB-065295st1', # LUSC
               'LIB-064886st1', 'LIB-064887st1', 'LIB-064889st1', 'LIB-064890st1', 'LIB-065290st1'] # LUAD

# Niche x sample cell fractions, normalized within each sample over every
# leiden_joint label (so 'Other' is in the denominator) before subsetting rows.
vc = adata.obs[['sample_id','leiden_joint']].value_counts().unstack().T
vc = vc / vc.sum(axis=0)
vc = vc.loc[cl_order, sample_order]

# --- Style setup (Nature Genetics–like aesthetic) ---
sns.set_theme(context="talk", style="white")

# --- Custom colormap: near-white -> pink -> dark red ---
cmap = LinearSegmentedColormap.from_list(
    "vlag_redgray",
    ["#f7f7f7", "#f4a3a8", "#b40426"]
)


def _format_percent(val):
    """Annotation text: 'N.A.' if missing, '<1%' below 1, else the rounded percent."""
    if pd.isna(val):
        return "N.A."
    if val < 1:
        return "<1%"
    return f"{val:.0f}%"


# --- Prepare annotation matrix ---
annot = vc * 100  # convert to percent
# Build a separate string table. Writing strings cell by cell into a copy of
# the float frame upcasts in place and triggers pandas' incompatible-dtype warning.
annot_fmt = annot.apply(lambda column: column.map(_format_percent))

# --- Plot ---
fig, ax = plt.subplots(figsize=(3.5, 5), dpi=300)

sns.heatmap(
    vc,
    cmap=cmap,
    annot=annot_fmt,
    fmt="",
    linewidths=0.4,
    linecolor="white",
    cbar=False,
    annot_kws={"fontsize": 7, "color": "black"},
    ax=ax,
)

# --- Aesthetic adjustments ---
ax.set_xlabel("", fontsize=11)
ax.set_ylabel("", fontsize=11)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=7)
ax.set_yticks(np.arange(len(vc.index)) + 0.5)
ax.set_yticklabels(vc.index, rotation=0, fontsize=7)
ax.tick_params(length=0)

# Remove borders and extra gridlines
for spine in ax.spines.values():
    spine.set_visible(False)

# Optional title
ax.set_title("", fontsize=12, pad=10, fontweight="normal")

plt.tight_layout()
fig.savefig('../../../SpatialFusion/results/figures_Fig6/niches_proportions.svg')
plt.show()


# Plot cell type composition

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, Sequence, Mapping, Tuple, Dict, List
from matplotlib.colors import to_hex, to_rgb
import colorsys


def plot_cluster_composition_stacked(
    df,
    cluster_key: str = "leiden",
    type_key: str = "celltypes",                 # or "cellsubtypes"
    cluster_order: Optional[Sequence[str]] = None,
    strict_order: bool = False,                  # if True, only show clusters in cluster_order
    palette: Optional[Mapping[str, str]] = None, # dict {celltype: color}; auto if None
    top_types: Optional[int] = None,             # keep top N types globally, rest→"Other"
    min_frac: Optional[float] = None,            # keep types with global frac >= min_frac, rest→"Other"
    other_label: str = "Other",
    type_order: Optional[Sequence[str]] = None,  # custom order of stack segments
    figsize=(10, 5),
    percent_labels: bool = False,                # print % on bars
    label_threshold: float = 0.05,               # only label segments >=5%
    savefig: Optional[str] = None,
):
    """
    Plot a 100% stacked barplot of type proportions per cluster.

    Rows with a missing cluster or type are ignored; both columns are compared
    as strings.

    Parameters
    ----------
    df : pd.DataFrame
        Per-cell table (e.g. ``adata.obs``) holding ``cluster_key`` and ``type_key``.
    cluster_key, type_key : str
        Columns giving the bars (clusters) and the stacked segments (types).
    cluster_order : sequence of str, optional
        Bar order. Entries absent from the data are skipped with a warning.
    strict_order : bool
        If True, only clusters in ``cluster_order`` are shown; otherwise the
        remaining clusters are appended after them.
    palette : mapping, optional
        {type: color}; types missing from it are drawn "#BBBBBB". When None,
        tab10 colors are used and ``other_label`` is gray.
    top_types, min_frac : optional
        Keep the ``top_types`` most frequent types and/or the types whose
        global fraction is >= ``min_frac``; all other types are summed into
        ``other_label``.
    other_label : str
        Name of the pooled segment; always stacked last.
    type_order : sequence of str, optional
        Stack order; unlisted types follow in their existing order. Duplicate
        entries are ignored.
    figsize : tuple
        Figure size in inches.
    percent_labels : bool
        Write the percentage inside segments of at least ``label_threshold``.
    label_threshold : float
        Minimum segment fraction that receives a label.
    savefig : str, optional
        Path to save the figure to (dpi=200).

    Returns
    -------
    pd.DataFrame
        Long-form table with columns: cluster, type, frac, percent, count.

    Raises
    ------
    ValueError
        If no clusters remain to plot.
    """
    from matplotlib.ticker import PercentFormatter

    obs = df[[cluster_key, type_key]].dropna().copy()
    obs[cluster_key] = obs[cluster_key].astype(str)
    obs[type_key]    = obs[type_key].astype(str)

    # Cross-tab counts (rows=clusters, cols=types)
    ct = pd.crosstab(obs[cluster_key], obs[type_key])

    # Global filtering of rare types (optional). Previously keep_cols started
    # from *all* types, so min_frac on its own filtered nothing.
    if (top_types is not None) or (min_frac is not None):
        totals = ct.sum(axis=0)
        keep_cols = []
        if top_types is not None:
            keep_cols = totals.sort_values(ascending=False).head(top_types).index.tolist()
        if min_frac is not None:
            global_frac = totals / totals.sum()
            keep_cols = sorted(set(keep_cols) | set(global_frac[global_frac >= min_frac].index))
        other = ct.drop(columns=keep_cols, errors="ignore").sum(axis=1)
        ct = ct[keep_cols].copy()
        if (other > 0).any():
            # Add to (rather than overwrite) a kept type that is itself named other_label.
            ct[other_label] = ct[other_label] + other if other_label in ct.columns else other
        # make sure "Other" is last
        ct = ct[[c for c in ct.columns if c != other_label] + ([other_label] if other_label in ct.columns else [])]

    # Normalize rows to 1.0 (100%)
    row_sums = ct.sum(axis=1).replace(0, np.nan)
    props = ct.div(row_sums, axis=0).fillna(0.0)

    # Cluster reordering: clusters missing from the data are really skipped
    # (before, non-strict mode kept them as empty NaN bars).
    if cluster_order is not None:
        cluster_order = [str(c) for c in cluster_order]
        missing = [c for c in cluster_order if c not in props.index]
        present = [c for c in cluster_order if c in props.index]
        if strict_order:
            new_index = present
        else:
            listed = set(cluster_order)
            new_index = present + [c for c in props.index if c not in listed]
        props = props.loc[new_index]
        if missing:
            print(f"Warning: these clusters from cluster_order were not found and will be skipped: {missing}")
    else:
        props = props.sort_index()

    if props.empty:
        raise ValueError("No clusters to plot after filtering/reordering.")

    # Determine stack (type) order; duplicates would be stacked twice.
    types_order = props.columns.tolist()
    if type_order is not None:
        requested = list(dict.fromkeys(t for t in type_order if t in props.columns))
        requested_set = set(requested)
        types_order = requested + [t for t in props.columns if t not in requested_set]

    # Build color map
    if palette is None:
        base = sns.color_palette("tab10", n_colors=max(10, len(types_order)))
        colmap = dict(zip(types_order, base[:len(types_order)]))
        if other_label in types_order:
            colmap[other_label] = "#B0B0B0"  # gray for "Other"
    else:
        colmap = {t: palette.get(t, "#BBBBBB") for t in types_order}

    # Plot (stacked bars)
    fig, ax = plt.subplots(figsize=figsize)
    x = np.arange(len(props.index))
    bottom = np.zeros(len(props))
    for t in types_order:
        vals = props[t].to_numpy()
        ax.bar(x, vals, bottom=bottom, width=0.9, color=colmap[t], label=t, edgecolor="none")
        bottom += vals

    ax.set_xticks(x)
    ax.set_xticklabels(props.index, rotation=45, ha="right")
    ax.set_ylim(0, 1)
    # Heights are fractions; render ticks as percent to match the axis label.
    ax.yaxis.set_major_formatter(PercentFormatter(xmax=1))
    ax.set_ylabel("Composition (% of cells)")
    ax.set_xlabel(cluster_key)
    ax.set_title(f"{type_key} composition per {cluster_key}")

    if percent_labels:
        for i, heights in enumerate(props[types_order].to_numpy()):
            cum = 0.0
            for h in heights:
                if h >= label_threshold:
                    ax.text(i, cum + h/2, f"{h*100:.0f}%", ha="center", va="center", fontsize=8, color="white")
                cum += h

    ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", title=type_key)
    fig.tight_layout()

    if savefig:
        fig.savefig(savefig, dpi=200, bbox_inches='tight')
    plt.show()

    # Long-form result for downstream use
    plot_df = (
        props.reset_index()
             .melt(id_vars=cluster_key, var_name="type", value_name="frac")
             .rename(columns={cluster_key: "cluster"})
    )
    plot_df["percent"] = (plot_df["frac"] * 100).round(2)
    counts_long = (
        ct.reset_index()
          .melt(id_vars=cluster_key, var_name="type", value_name="count")
          .rename(columns={cluster_key: "cluster"})
    )
    plot_df = plot_df.merge(counts_long, on=["cluster", "type"], how="left")
    return plot_df

def _lightness_shades(base_color, n, l_low=0.35, l_high=0.85):
    """Return ``n`` hex colors sharing ``base_color``'s hue and saturation.

    HLS lightness runs evenly from ``l_low`` (first, darkest) to ``l_high``
    (last, lightest); the base color's own lightness is ignored.
    """
    hue, _, saturation = colorsys.rgb_to_hls(*to_rgb(base_color))
    return [
        to_hex(colorsys.hls_to_rgb(hue, lightness, saturation))
        for lightness in np.linspace(l_low, l_high, n)
    ]

def make_hierarchical_palettes(
    df,
    parent_key: str = "celltypes",
    child_key: str  = "cellsubtypes",
    parent_order: Optional[Sequence[str]] = None,
    child_order: str = "alpha",
    base_palette: Optional[Mapping[str, str]] = None,  # legacy param
    parent_palette_dict: Optional[Dict[str, Tuple[float, float, float]]] = None,  # ✅ new
    unknown_parent_color: str = "#9e9e9e",
    shade_lightness: Tuple[float, float] = (0.35, 0.85),
) -> Tuple[Dict[str, str], Dict[str, str], List[str], List[str]]:
    """
    Generate a parent palette plus child colors that are shades of their parent.

    Labels are compared as strings (missing values become 'nan').

    Parameters
    ----------
    df : pd.DataFrame
        Table holding ``parent_key`` and ``child_key`` (e.g. ``adata.obs``).
    parent_key, child_key : str
        Columns with the coarse (parent) and fine (child) labels.
    parent_order : sequence of str, optional
        Preferred parent order; parents not listed follow in first-seen order.
        Default: parents by decreasing frequency.
    child_order : {"alpha", "freq"}
        Order of children within a parent: "freq" = decreasing frequency,
        anything else = alphabetical.
    base_palette : mapping, optional
        Legacy {parent: color}; used only when ``parent_palette_dict`` is None.
    parent_palette_dict : dict, optional
        If provided, should be {parent_label: color (hex or RGB)}.
        These colors are used directly for parent_palette and for shading.
    unknown_parent_color : str
        Color for parents missing from the supplied palette.
    shade_lightness : (float, float)
        HLS lightness range (darkest, lightest) used for child shades.

    Returns
    -------
    parent_palette : {parent -> hex}
    child_palette  : {child -> hex}
    parent_order_out : list of parents
    child_order_out  : list of children grouped by parent, each child once.
        A child label that occurs under several parents is assigned to the
        first of those parents in ``parent_order_out``.
    """
    obs = df[[parent_key, child_key]].copy()
    obs[parent_key] = obs[parent_key].astype(str)
    obs[child_key]  = obs[child_key].astype(str)

    # --- Parent order ---
    parents = obs[parent_key].unique().tolist()
    if parent_order is not None:
        parent_order_out = [p for p in parent_order if p in parents] + \
                           [p for p in parents if p not in parent_order]
    else:
        freq = obs[parent_key].value_counts()
        parent_order_out = freq.index.tolist() + [p for p in parents if p not in freq.index]

    # --- Children per parent ---
    children_per_parent = {}
    for p in parents:
        sub = obs.loc[obs[parent_key] == p, child_key]
        if child_order == "freq":
            children = sub.value_counts().index.tolist()
        else:
            children = sorted(sub.unique().tolist())
        children_per_parent[p] = children

    # --- Build parent colors (always hex, as documented) ---
    supplied = parent_palette_dict if parent_palette_dict is not None else base_palette
    if supplied is not None:
        # fallback to gray if missing
        parent_palette = {
            p: to_hex(supplied.get(p, unknown_parent_color))
            for p in parent_order_out
        }
    else:
        n_par = len(parent_order_out)
        base = sns.color_palette("tab10" if n_par <= 10 else "hls", n_colors=n_par)
        parent_palette = {p: to_hex(base[i]) for i, p in enumerate(parent_order_out)}

    # --- Build child colors as shades of parent ---
    l_low, l_high = shade_lightness
    child_palette = {}
    child_order_out = []
    for p in parent_order_out:
        # Skip children already colored under an earlier parent; otherwise they
        # would be listed twice (and stacked twice in composition plots).
        kids = [k for k in children_per_parent.get(p, []) if k not in child_palette]
        if not kids:
            continue
        shades = _lightness_shades(parent_palette[p], len(kids), l_low, l_high)
        child_palette.update(zip(kids, shades))
        child_order_out.extend(kids)

    return parent_palette, child_palette, parent_order_out, child_order_out



In [None]:
# Lineage colors come from palette_dict_1; subtype colors are shades of their
# lineage color, and child_order lists subtypes grouped by lineage.
parent_pal, child_pal, parent_order, child_order = make_hierarchical_palettes(
    adata.obs,
    parent_key="refined_celltypes",
    child_key="refined_cellsubtypes",
    parent_palette_dict=palette_dict_1['refined_celltypes'],
    child_order="alpha",                 # or "freq"
    shade_lightness=(0.35, 0.85)
)

# (A) 100% stacked bars by lineage (coarse)
# NOTE(review): niches '0'..'17' are hard-coded; with strict_order=False any other
# label present (e.g. 'Other') is appended after them.
plot_df_types = plot_cluster_composition_stacked(
    adata.obs,
    cluster_key="leiden_joint",
    type_key="refined_celltypes",
    cluster_order=[f"{i}" for i in range(18)],  # niches shown first, in this order
    strict_order=False,
    palette=parent_pal,
    percent_labels=True,
    figsize=(10, 4),
    savefig='../../../SpatialFusion/results/figures_Fig6/NSCLC_major_celltype_stacked_finetuned_barplot.svg',
)

# (B) 100% stacked bars by subtypes (fine), colors are shades within lineage color
plot_df_subtypes = plot_cluster_composition_stacked(
    adata.obs,
    cluster_key="leiden_joint",
    type_key="refined_cellsubtypes",
    cluster_order=[f"{i}" for i in range(18)],
    strict_order=False,
    palette=child_pal,
    # Stack subtypes grouped by their parent lineage
    type_order=child_order,              # from make_hierarchical_palettes above
    percent_labels=True,
    figsize=(15, 10),
    savefig='../../../SpatialFusion/results/figures_Fig6/NSCLC_minor_celltype_stacked_finetuned_barplot.svg',
)


In [None]:
# Manual grouping of niche labels (leiden_joint) by compartment.
# NOTE(review): '2' appears in both malignant_niches and malignant_immune_niches,
# and '18'/'20' lie outside the '0'..'17' niches plotted above — confirm intended.
malignant_niches = ['0','1','2','4','5','7','10','12','15','16',]
malignant_immune_niches = ['2','13','18','20',]
malignant_stromal_niches = ['8',]
tme_niches = ['3','6','9','11','14',]

In [None]:
adata.obs['Histology'] = adata.obs['sample_id'].replace(clinical_info[['Library','Histology']].set_index('Library').to_dict()['Histology'])

In [None]:
cts = adata.obs[['Histology','leiden_joint']].value_counts().unstack()
cts = cts/cts.sum(axis=0)

In [None]:
fig, ax = plt.subplots(1,1,figsize=(15,1))
sns.heatmap(cts.loc[:,[f'{i}' for i in range(18)]], annot=True, fmt='.1f', cmap='vlag', center=0.5)
ax.set_xlabel('')
ax.set_ylabel('')

# Associate with clinical characteristics


In [None]:
import pandas as pd
import numpy as np
from scipy.stats import mannwhitneyu, kruskal, spearmanr

def analyze_cluster_associations(df, cluster_cols=None, alpha=0.05):
    """
    Test every (cluster proportion, clinical variable) pair for association.

    For each pair, rows where either value is missing are dropped, then:
      - fewer than 4 rows, or a single distinct value: the pair is skipped
      - exactly 2 distinct values: Mann–Whitney U (two-sided)
      - non-numeric (object/category/string) with > 2 values: Kruskal–Wallis
      - numeric with > 2 values: Spearman correlation

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing clinical variables and cluster proportions.
    cluster_cols : list, optional
        Cluster proportion columns. If None, columns whose label (as a string)
        is numeric, e.g. '0', '1', ...
    alpha : float
        Not used here; kept for API compatibility.

    Returns
    -------
    results : pd.DataFrame
        Columns cluster, variable, variable_type, test, statistic, p_value, n,
        p_adj (Benjamini–Hochberg over all rows, missing p-values counted as 1),
        sorted by p_adj. Empty, with the same columns, if nothing was tested.
    """
    out_columns = ['cluster', 'variable', 'variable_type', 'test',
                   'statistic', 'p_value', 'n', 'p_adj']

    # Automatically detect cluster columns; str() tolerates non-string labels.
    if cluster_cols is None:
        cluster_cols = [col for col in df.columns if str(col).isnumeric()]

    # Other columns = clinical variables
    clinical_cols = [col for col in df.columns if col not in cluster_cols]

    results = []
    for cluster in cluster_cols:
        cluster_values = df[cluster]
        for var in clinical_cols:
            # Drop missing values jointly, always starting from the full columns.
            # (Re-using the filtered cluster column of the previous variable
            # carried its dropped rows over to every later variable.)
            pair = pd.concat([df[var], cluster_values], axis=1).dropna()
            n = pair.shape[0]
            if n < 4:
                continue  # skip too small sample

            x = pair[var]
            y = pair[cluster]

            nunique = x.nunique()
            if nunique < 2:
                continue  # constant variable: nothing to test

            if nunique == 2:
                var_type, test = 'binary', 'mannwhitney'
                level_a, level_b = x.unique()
                stat, p = mannwhitneyu(y[x == level_a], y[x == level_b], alternative='two-sided')
            elif not pd.api.types.is_numeric_dtype(x):
                var_type, test = 'categorical', 'kruskal'
                try:
                    stat, p = kruskal(*(y[x == level] for level in x.unique()))
                except ValueError:
                    # kruskal rejects input in which all values are identical
                    stat, p = np.nan, np.nan
            else:
                var_type, test = 'continuous', 'spearman'
                stat, p = spearmanr(x, y)

            results.append({
                'cluster': cluster,
                'variable': var,
                'variable_type': var_type,
                'test': test,
                'statistic': stat,
                'p_value': p,
                'n': n,
            })

    res_df = pd.DataFrame(results, columns=out_columns)
    if res_df.empty:
        return res_df

    # Benjamini–Hochberg across all tests in this table.
    from statsmodels.stats.multitest import multipletests
    res_df['p_adj'] = multipletests(res_df['p_value'].fillna(1), method='fdr_bh')[1]
    return res_df.sort_values('p_adj')

In [None]:
def leave_one_out_stability(df,
                            patient_col='patient_id',
                            cluster_cols=None,
                            alpha=0.05):
    """
    Re-run the association analysis once per held-out patient and summarize
    how stable each (cluster, variable) p-value is across those runs.

    Parameters
    ----------
    df : pd.DataFrame
        Clinical + cluster-proportion table.
    patient_col : str
        Column identifying patients; each unique value is dropped in turn.
    cluster_cols : optional
        Forwarded unchanged to ``analyze_cluster_associations``.
    alpha : float
        Significance threshold for ``n_significant``; also forwarded to
        ``analyze_cluster_associations``.

    Returns
    -------
    summary_df : pd.DataFrame
        One row per (cluster, variable) with mean_p_value, sd_p_value,
        n_runs, n_significant and prop_significant, sorted by
        prop_significant (descending) then mean_p_value (ascending).
    pval_matrix : pd.DataFrame
        Wide table: cluster and variable columns plus one p-value column per
        left-out patient.
    all_runs_df : pd.DataFrame
        Long table concatenating every run, with a ``left_out`` column.
    """

    def _run_without(patient):
        # Association analysis on every row except this patient's.
        subset = df.loc[df[patient_col] != patient].copy()
        run = analyze_cluster_associations(subset, cluster_cols=cluster_cols, alpha=alpha)
        run['left_out'] = patient
        return run

    def _count_significant(pvals):
        return np.sum(pvals < alpha)

    runs = [_run_without(patient) for patient in df[patient_col].unique()]
    all_runs_df = pd.concat(runs, ignore_index=True)

    keys = ['cluster', 'variable']
    pval_matrix = all_runs_df.pivot_table(index=keys, columns='left_out', values='p_value')
    pval_matrix = pval_matrix.reset_index()

    grouped = all_runs_df.groupby(keys, as_index=False)
    summary_df = grouped.agg(
        mean_p_value=('p_value', 'mean'),
        sd_p_value=('p_value', 'std'),
        n_runs=('left_out', 'nunique'),
        n_significant=('p_value', _count_significant),
    )
    summary_df['prop_significant'] = summary_df['n_significant'] / summary_df['n_runs']
    ranked = summary_df.sort_values(['prop_significant', 'mean_p_value'], ascending=[False, True])

    return ranked, pval_matrix, all_runs_df


In [None]:
prop_table = vc.T.fillna(0)

In [None]:
df = pd.concat([clinical_info.drop(['Gender','Age'],axis=1).set_index('Library'),prop_table], axis=1)
df = df.rename(columns={'tumor size [cm]': 'tumor size'})

In [None]:
df['patient_id'] = df.index

In [None]:
# Example usage:
# df = your clinical+proportion dataframe
results_df = analyze_cluster_associations(df)

In [None]:
results_df = results_df.sort_values('p_value')
res = results_df[results_df['p_value']<0.05]
res.to_csv('../../../SpatialFusion/results/figures_Fig6/results_dissociated.csv')

res

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_single_boxplot(df, cluster, var, pval, order, savefig=None):
    """
    Box + strip plot of one cluster's proportion across the levels of `var`.

    Rows missing either the cluster or the variable are dropped. `order` fixes
    the x-axis level order, and `pval` is shown in the title. If `savefig` is
    given, the figure is saved there at 200 dpi.
    """
    complete = df[[cluster, var]].dropna()
    levels, proportions = complete[var], complete[cluster]

    fig, ax = plt.subplots(1, 1, figsize=(3, 2))
    shared = dict(x=levels, y=proportions, order=order, ax=ax)
    sns.boxplot(palette="Set2", width=0.6, **shared)
    sns.stripplot(color="black", alpha=0.6, size=5, **shared)

    ax.spines[['right', 'top']].set_visible(False)
    ax.set_title(f"Cluster {cluster} vs {var}\n(p = {pval:.3g})")
    ax.set_xlabel('')
    ax.set_ylabel(f"Cluster {cluster} proportion")

    if savefig is not None:
        fig.savefig(savefig, dpi=200, bbox_inches='tight')

def plot_single_lineplot(df, cluster, var, pval, statistic, savefig=None):
    """
    Scatter plot with a linear fit of one cluster's proportion against `var`.

    Rows missing either column are dropped. The title reports `statistic` as
    Spearman's rho together with `pval`. If `savefig` is given, the figure is
    saved there at 200 dpi.
    """
    complete = df[[cluster, var]].dropna()

    fig, ax = plt.subplots(1, 1, figsize=(3, 2))
    sns.regplot(
        x=complete[var],
        y=complete[cluster],
        scatter_kws={"s": 60, "alpha": 0.7},
        line_kws={"color": "red"},
        ax=ax,
    )

    ax.spines[['right', 'top']].set_visible(False)
    ax.set_title(f"Cluster {cluster} vs {var}\nSpearman ρ={statistic:.2f}, p={pval:.3g}")
    ax.set_xlabel('')
    ax.set_ylabel(f"Cluster {cluster} proportion")

    if savefig is not None:
        fig.savefig(savefig, dpi=200, bbox_inches='tight')

In [None]:
cluster = '10'
var = 'tumor size'
# Look the result up by (cluster, variable) instead of the hard-coded row
# label results_df.loc[427]. Row labels change whenever the association table
# is regenerated, which would silently plot another test's p-value.
hit = results_df[(results_df['cluster'] == cluster) & (results_df['variable'] == var)]
if hit.empty:
    raise KeyError(f"No association result for cluster {cluster!r} and variable {var!r}")
pval = hit['p_value'].iloc[0]
statistic = hit['statistic'].iloc[0]

plot_single_lineplot(df, cluster, var, pval, statistic, savefig='../../../SpatialFusion/results/figures_Fig6/tumorsize_assoc_cl10.svg')

In [None]:
cluster = '16'
var = 'pTNM T red'
# Look the p-value up by (cluster, variable) instead of the hard-coded row
# label results_df.loc[631], which breaks silently if the table is regenerated.
hit = results_df[(results_df['cluster'] == cluster) & (results_df['variable'] == var)]
if hit.empty:
    raise KeyError(f"No association result for cluster {cluster!r} and variable {var!r}")
pval = hit['p_value'].iloc[0]
order=['T1/2', 'T3/4']

plot_single_boxplot(df, cluster, var, pval, order, 
                    savefig='../../../SpatialFusion/results/figures_Fig6/Tred_assoc_cl16.svg')

## Compare against dissociated

In [None]:
# --- Data prep: per-sample cell-subtype proportions (same steps as the association analysis above) ---
sample_order = ['LIB-064885st1','LIB-065291st1','LIB-065292st1','LIB-065294st1','LIB-065295st1', # LUSC
               'LIB-064886st1', 'LIB-064887st1', 'LIB-064889st1', 'LIB-064890st1', 'LIB-065290st1'] # LUAD

vc_ct = adata.obs[['sample_id','refined_cellsubtypes']].value_counts().unstack().T
vc_ct = vc_ct / vc_ct.sum(axis=0)
vc_ct = vc_ct.loc[:, sample_order]

In [None]:
prop_table = vc_ct.T.fillna(0)

In [None]:
prop_table = prop_table.loc[:,(prop_table<0.01).sum(axis=0)[(prop_table<0.01).sum(axis=0)<8].index]

In [None]:
df = pd.concat([clinical_info.drop(['Gender','Age'],axis=1).set_index('Library'),prop_table], axis=1)
df = df.rename(columns={'tumor size [cm]': 'tumor size'})

In [None]:
df['patient_id'] = df.index

In [None]:
# Example usage:
# df = your clinical+proportion dataframe
results_df = analyze_cluster_associations(df, cluster_cols=prop_table.columns)

In [None]:
results_df = results_df.sort_values('p_value')
res = results_df[results_df['p_value']<0.05]
res.to_csv('../../../SpatialFusion/results/figures_Fig6/results_dissociated.csv')

res

# Characterize niches

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_with_highlights(
    df,
    title,
    highlight_clusters,
    cats,
    base_palette,
    highlight_palette,
    *,
    x_col="X_coord",
    y_col="Y_coord",
    cluster_col="leiden_joint",
    s_base=1.2,
    s_highlight=5,
    alpha_base=0.25,
    alpha_highlight=0.9,
    rasterized=True,
    savefig=None,
):
    """
    Spatial scatterplot of one sample with selected clusters drawn on top.

    Clusters in `cats` that are not highlighted are drawn first with small,
    translucent markers. The highlighted clusters are then drawn over them
    with larger, more opaque markers. The y axis is inverted, the aspect is
    equal, and ticks/spines/labels are hidden.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain `x_col`, `y_col` and `cluster_col`.
    title : str
        Axes title (e.g. the sample ID).
    highlight_clusters : list
        Cluster labels drawn in the foreground.
    cats : list
        Ordered list of all cluster labels. Every label not in
        `highlight_clusters` is drawn as background.
    base_palette : dict
        {cluster: color} fallback for clusters missing from
        `highlight_palette`. Clusters in neither palette are drawn light gray.
    highlight_palette : dict
        {cluster: color} colors for both background and foreground clusters.
        Callers grey out the background entries here.
    x_col, y_col, cluster_col : str
        Column names for coordinates and cluster labels.
    s_base, s_highlight : float
        Marker sizes for background / foreground points.
    alpha_base, alpha_highlight : float
        Marker opacities for background / foreground points.
    rasterized : bool
        Passed through to the scatter layers.
    savefig : str or path, optional
        If given, the figure is saved there (dpi=250) before being shown.

    Notes
    -----
    Calls ``sns.set_style("white")`` and ``sns.set_context("talk")``, which
    change seaborn's global settings for later plots as well.
    """
    sns.set_style("white")
    sns.set_context("talk")

    fallback = base_palette or {}

    def _color(cluster):
        # `base_palette` used to be accepted but ignored, and a cluster missing
        # from `highlight_palette` raised KeyError. Fall back instead.
        if cluster in highlight_palette:
            return highlight_palette[cluster]
        return fallback.get(cluster, "lightgray")

    fig, ax = plt.subplots(figsize=(6, 5), dpi=300)

    background = [c for c in cats if c not in highlight_clusters]
    # Background first so the highlighted points end up on top.
    layers = (
        (background, s_base, alpha_base),
        (list(highlight_clusters), s_highlight, alpha_highlight),
    )
    for labels, size, alpha in layers:
        subset = df[df[cluster_col].isin(labels)]
        if subset.empty:
            continue
        sns.scatterplot(
            data=subset,
            x=x_col, y=y_col,
            hue=cluster_col,
            hue_order=labels,
            palette={c: _color(c) for c in labels},
            s=size,
            alpha=alpha,
            linewidth=0,
            legend=False,
            rasterized=rasterized,
            ax=ax,
        )

    # Image-style orientation, no axes decoration.
    ax.invert_yaxis()
    ax.set_aspect("equal", adjustable="box")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("")
    ax.set_ylabel("")
    for spine in ["top", "right", "left", "bottom"]:
        ax.spines[spine].set_visible(False)

    ax.set_title(title, pad=6, fontsize=12, fontweight="normal")

    plt.tight_layout()
    if savefig is not None:
        fig.savefig(savefig, bbox_inches='tight', dpi=250)
    plt.show()


In [None]:
# --- Choose clusters to highlight (string labels!)
highlight_clusters = ['10','16']
other_important = ['0','1','2','4','5','7','12','15']

# --- Colors
light_gray = "lightgray"
dark_gray  = "dimgray"   # or "#555555"

# --- Prep
emb = adata.obs.copy()
emb["leiden_joint"] = emb["leiden_joint"].astype(str)

# --- Global, ordered categories across ALL samples
# (numeric labels first in numeric order, any other labels after them)
all_labels = pd.Index(pd.unique(emb["leiden_joint"]))
def sort_key(lbl):
    return (0, int(lbl)) if lbl.isdigit() else (1, lbl)
cats = sorted(map(str, all_labels), key=sort_key)

from pandas.api.types import CategoricalDtype
cat_type = CategoricalDtype(categories=cats, ordered=True)
emb["leiden_joint"] = emb["leiden_joint"].astype(cat_type)

# --- Build palette
# This must run after `cats` is computed. It used to run first, which raised
# NameError on a fresh kernel (unless an earlier cell happened to define
# `cats`) and otherwise built the palette from a stale `cats`.
highlight_palette = {}
for lbl in cats:
    if lbl in highlight_clusters:
        # keep original color
        highlight_palette[lbl] = palette_dict_1['leiden_joint'].get(lbl, light_gray)
    elif lbl in other_important:
        highlight_palette[lbl] = dark_gray
    else:
        highlight_palette[lbl] = light_gray

# --- Apply to ANY number of samples
for sample_id, df_sample in emb.groupby("sample_id", sort=True):
    plot_with_highlights(
        df_sample, sample_id, highlight_clusters, cats,
        palette_dict_1['leiden_joint'], highlight_palette,
        s_base=1, s_highlight=2,
        savefig=f'../../../SpatialFusion/results/figures_Fig6/{sample_id}_highlight_10_16.svg',
    )


# Pathway activation

In [None]:
def summarize_cluster_pathways(
    adata_obs,
    pathways,
    cluster_col="kmeans_cluster",
    clusters=None,
    global_cluster_order=None
):
    """
    Summarize pathway scores per cluster as a tidy table.

    Parameters
    ----------
    adata_obs : pd.DataFrame
        Per-cell table (e.g. ``adata.obs``) holding `cluster_col` and one
        column per pathway score.
    pathways : sequence of str
        Pathway score columns to summarize (list, tuple, Index, ...).
    cluster_col : str
        Column with cluster labels. Labels are compared as strings.
    clusters : sequence, optional
        Subset of clusters to keep, in the order given.
    global_cluster_order : sequence, optional
        Cluster order used when `clusters` is None. By default, the labels
        present in the data are used: purely numeric labels first, in
        numeric order ('2' before '10'), then any other labels alphabetically.

    Returns
    -------
    pd.DataFrame
        Columns Cluster (ordered categorical), Pathway, mean, std, n, sem,
        sorted by Cluster then Pathway. Missing scores are ignored, and `n`
        counts the non-missing scores.

    Raises
    ------
    ValueError
        If `cluster_col` or any requested pathway column is missing.
    """
    # Accept any sequence; `[cluster_col] + pathways` below needs a list.
    pathways = list(pathways)

    if cluster_col not in adata_obs.columns:
        raise ValueError(f"'{cluster_col}' not found in adata.obs")

    missing = [p for p in pathways if p not in adata_obs.columns]
    if missing:
        raise ValueError(f"These pathways are missing in adata.obs: {missing}")

    df = adata_obs[[cluster_col] + pathways].copy()
    df[cluster_col] = df[cluster_col].astype(str)

    def _natural_key(label):
        # Numeric labels numerically ('2' < '10'); other labels afterwards.
        return (0, int(label)) if label.isdecimal() else (1, label)

    # cluster universe
    if global_cluster_order is None:
        all_clusters = sorted(df[cluster_col].unique(), key=_natural_key)
    else:
        all_clusters = [str(c) for c in global_cluster_order]

    # subset clusters if requested
    if clusters is not None:
        selected = [str(c) for c in clusters]
    else:
        selected = all_clusters

    df = df[df[cluster_col].isin(selected)].copy()

    # melt -> summarize -> tidy
    long = df.melt(id_vars=[cluster_col], value_vars=pathways,
                   var_name="Pathway", value_name="Score").dropna(subset=["Score"])

    summary = (
        long.groupby([cluster_col, "Pathway"], as_index=False)
            .agg(mean=("Score", "mean"),
                 std=("Score", "std"),
                 n=("Score", "size"))
    )
    summary["sem"] = summary["std"] / np.sqrt(summary["n"].clip(lower=1))
    summary.rename(columns={cluster_col: "Cluster"}, inplace=True)

    # keep a consistent cluster order
    summary["Cluster"] = pd.Categorical(summary["Cluster"], categories=selected, ordered=True)
    summary = summary.sort_values(["Cluster", "Pathway"]).reset_index(drop=True)
    return summary

In [None]:
def plot_pathway_bars_by_cluster(
    adata_obs,
    pathways,
    cluster_col="kmeans_cluster",
    clusters=None,
    global_cluster_order=None,
    err="sem",        # "sem" or "ci"
    ci_level=95,
    figsize=None,
    ylim=None,
    palette=None,
    savefig=None,
):
    """
    Grouped bar chart of mean pathway score: pathways on x, one bar per cluster.

    Means and SEMs come from ``summarize_cluster_pathways``. Error bars are
    drawn manually as either the SEM or a normal-approximation CI (z * SEM).

    Parameters
    ----------
    adata_obs : pd.DataFrame
        Per-cell table (e.g. ``adata.obs``) with `cluster_col` and the
        pathway columns.
    pathways : list of str
        Pathway columns. This also fixes the x-axis order.
    cluster_col, clusters, global_cluster_order
        Forwarded to ``summarize_cluster_pathways``.
    err : str
        "ci" for a normal-approximation CI. Any other value gives SEM.
    ci_level : float
        CI level in percent (used only when ``err == "ci"``).
    figsize : tuple, optional
        Defaults to a size that grows with the number of pathways/clusters.
    ylim : tuple, optional
        y-axis limits.
    palette : dict, list, or str, optional
        A dict mapping cluster labels to colors, or anything
        ``sns.color_palette`` accepts. If None, tab20 colors are used.
    savefig : str or path, optional
        If given, the figure is saved there (dpi=200) before being shown.

    Returns
    -------
    pd.DataFrame
        The plotted summary table, with an added ``yerr`` column.
    """
    from scipy.stats import norm

    summary = summarize_cluster_pathways(
        adata_obs, pathways, cluster_col=cluster_col,
        clusters=clusters, global_cluster_order=global_cluster_order
    )

    # keep the Pathway order as provided by `pathways`
    summary["Pathway"] = pd.Categorical(summary["Pathway"], categories=pathways, ordered=True)
    summary = summary.sort_values(["Pathway", "Cluster"]).reset_index(drop=True)

    # compute error quantity
    if err == "ci":
        z = norm.ppf(0.5 + ci_level/200.0)
        summary["yerr"] = z * summary["sem"]
        err_label = f"{ci_level}% CI"
    else:
        summary["yerr"] = summary["sem"]
        err_label = "SEM"

    # hue palette for clusters
    cluster_levels = summary["Cluster"].cat.categories.tolist()
    n_clusters = len(cluster_levels)

    if palette is None:
        hue_palette = sns.color_palette("tab20", n_colors=max(10, n_clusters))
        cluster_palette = {c: hue_palette[i % len(hue_palette)] for i, c in enumerate(cluster_levels)}
    elif isinstance(palette, dict):
        cluster_palette = palette
    else:
        # can be a seaborn palette name or list
        cluster_palette = sns.color_palette(palette, n_colors=n_clusters)

    # figure size scales with #pathways
    if figsize is None:
        figsize = (max(8, 0.8 * len(pathways)), 5 + 0.15 * n_clusters)

    # Total width of one pathway's group of bars. It is passed to seaborn
    # explicitly so the manual error-bar offsets below use the same value
    # (0.8 is seaborn's default).
    group_width = 0.8

    plt.figure(figsize=figsize)
    ax = sns.barplot(
        data=summary,
        x="Pathway", y="mean",
        hue="Cluster",
        palette=cluster_palette,
        dodge=True,
        errorbar=None,
        width=group_width,
    )

    # Manual error bars. With dodge, hue level j sits at
    # x - group_width/2 + (j + 0.5) * bar_width within each pathway slot.
    x_categories = list(summary["Pathway"].cat.categories)
    x_positions = {p: i for i, p in enumerate(x_categories)}
    bar_width = group_width / max(1, n_clusters)

    for j, c in enumerate(cluster_levels):
        # Reindex so absent (cluster, pathway) pairs become NaN (not drawn).
        sub = summary[summary["Cluster"] == c].set_index("Pathway").reindex(x_categories).reset_index()
        xs = [x_positions[p] - group_width / 2 + (j + 0.5) * bar_width for p in sub["Pathway"]]
        ax.errorbar(xs, sub["mean"], yerr=sub["yerr"], fmt="none",
                    ecolor="black", elinewidth=1, capsize=3, capthick=1)

    ax.set_title(f"Mean pathway score per pathway (± {err_label})")
    ax.set_ylabel("Mean pathway score")
    ax.set_xlabel("Pathway")
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.legend(title="Cluster", bbox_to_anchor=(1.02, 1), loc="upper left")
    plt.tight_layout()
    if savefig is not None:
        plt.savefig(savefig, dpi=200, bbox_inches='tight')
    plt.show()

    return summary


In [None]:
# Load each sample's pathway-activation table and suffix its row labels with
# "::<sample>" (presumably to match adata.obs_names), then stack the tables.
path_mat = {}
for sample in sample_list:
    activation = pd.read_parquet(base_dir / sample / 'pathway_activation.parquet')
    activation.index = activation.index.astype(str) + '::' + sample
    path_mat[sample] = activation

path_mat = pd.concat(path_mat.values())

In [None]:
#adata.obs = pd.concat([adata.obs, path_mat.loc[adata.obs_names]],axis=1)

In [None]:
pathways = ['Androgen','EGFR','Estrogen','JAK-STAT','MAPK','NFkB','PI3K','TGFb','TNFa','VEGF']

plot_pathway_bars_by_cluster(
    adata.obs,
    pathways=pathways,
    cluster_col="leiden_joint",
    global_cluster_order=['0','1','2','3','4','5','6','7','8','9','10','11','12','13','14','15','16','17',],
    palette=palette_dict_1['leiden_joint'],
    err="ci",        # "sem" or "ci"
    ci_level=95, 
    figsize=(20,7),
    savefig='../../../SpatialFusion/results/figures_Fig6/pathway_score_1_barplot.svg'
)


In [None]:
pathways = ['Androgen','EGFR','Estrogen','JAK-STAT','MAPK','NFkB','PI3K','TGFb','TNFa','VEGF']

plot_pathway_bars_by_cluster(
    adata.obs,
    pathways=pathways,
    cluster_col="leiden_joint",
    global_cluster_order=malignant_niches,
    palette=palette_dict_1['leiden_joint'],
    err="ci",        # "sem" or "ci"
    ci_level=95, 
    figsize=(15,4),
    savefig='../../../SpatialFusion/results/figures_Fig6/pathway_score_malignant_niches_barplot.svg'
)


# Transform adata

In [None]:
adata.obs['X'] = adata.obsm['spatial'][:,0]
adata.obs['Y'] = adata.obsm['spatial'][:,1]

In [None]:
adata.layers['counts'] = adata.X.copy()

In [None]:
sc.pp.normalize_total(adata, target_sum=10000)
sc.pp.log1p(adata)

# Niches 10 and 16

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.path import Path
from matplotlib.patches import PathPatch, Rectangle
from typing import Optional, Tuple, Dict, Sequence

def _ribbon_path(x0, x1, y0_bot, y0_top, y1_bot, y1_top, curvature=0.5, steps=40):
    """
    Closed Path for a band running from [y0_bot, y0_top] at x0 to
    [y1_bot, y1_top] at x1.

    The interpolation weight blends a linear ramp (curvature=0) with a
    smoothstep 3t^2 - 2t^3 (curvature=1). The outline is the bottom edge left
    to right, then the top edge right to left, then a repeat of the first
    vertex for CLOSEPOLY (2*steps + 1 vertices in total).
    """
    t = np.linspace(0, 1, steps)
    x = np.linspace(x0, x1, steps)
    eased = 3*t**2 - 2*t**3
    blend = (1 - curvature) * t + curvature * eased

    def _edge(start, end):
        return (1 - blend) * start + blend * end

    lower = np.column_stack([x, _edge(y0_bot, y1_bot)])
    upper = np.column_stack([x, _edge(y0_top, y1_top)])[::-1]
    outline = np.vstack([lower, upper, lower[:1]])

    codes = np.full(len(outline), Path.LINETO)
    codes[0] = Path.MOVETO
    codes[-1] = Path.CLOSEPOLY
    return Path(outline, codes)

def alluvial_multi_groups(
    data: pd.DataFrame,
    group_col: str = "niche",
    subtype_col: str = "cell_subtype",
    weight_col: Optional[str] = None,
    normalize: bool = True,
    min_frac_to_label: float = 0.06,
    title: Optional[str] = None,
    outfile_prefix: Optional[str] = None,
    palette: Optional[Dict[str, str]] = None,
    subtype_order: Optional[Sequence[str]] = None,
    group_order: Optional[Sequence[str]] = None,
    figsize=(8, 4),
    curvature: float = 0.6,
) -> Tuple[plt.Figure, plt.Axes, pd.DataFrame, pd.DataFrame]:
    """
    Alluvial plot (stacked bars + ribbons) of subtype composition across groups.

    Each group gets one stacked bar showing the fraction of (optionally
    weighted) rows in each subtype. Ribbons connect the same subtype between
    consecutive bars.

    Parameters
    ----------
    data : pd.DataFrame
        One row per observation (e.g. per cell).
    group_col : str
        Column defining the bars. Labels are compared as strings, so numeric
        or categorical labels work with string `group_order` entries.
    subtype_col : str
        Column defining the stacked segments and ribbons.
    weight_col : str, optional
        Column of per-row weights. Each row counts as 1 when None.
    normalize : bool
        Only changes the y-axis label. Bar heights are per-group fractions
        in both cases.
    min_frac_to_label : float
        Segments at least this tall get a percentage label.
    title : str, optional
        Axes title. A generic title is used when None.
    outfile_prefix : str, optional
        Despite the name, this is the full output path passed to ``fig.savefig``.
    palette : dict, optional
        {subtype: color}. Subtypes not listed take the next color from the
        axes' color cycle.
    subtype_order : sequence, optional
        Stacking order from bottom to top; unlisted subtypes follow. Defaults
        to decreasing overall total.
    group_order : sequence, optional
        Order of the bars. Every entry must be present in `group_col`.
    figsize : tuple
        Figure size in inches.
    curvature : float
        0 gives straight ribbons, 1 a full smoothstep S-curve.

    Returns
    -------
    fig, ax, counts, props
        `counts` is groups x subtypes (weighted counts). `props` is its
        row-normalized version (rows with zero total become 0).

    Raises
    ------
    ValueError
        If a `group_order` entry is absent from the data, or fewer than two
        groups would be plotted.
    """
    # Use string group labels everywhere. Previously `groups` held strings
    # while the aggregation used the raw labels, so numeric group labels
    # raised KeyError at `.loc[groups]`.
    group_labels = data[group_col].astype(str)
    present_groups = list(group_labels.unique())
    if group_order is not None:
        groups = [str(g) for g in group_order]
        missing = [g for g in groups if g not in present_groups]
        if missing:
            raise ValueError(f"These groups not found in data: {missing}")
    else:
        groups = sorted(present_groups)

    if len(groups) < 2:
        raise ValueError("At least two groups are required for an alluvial plot.")

    if weight_col is None:
        data = data.copy()
        data["_w"] = 1.0
        weight_col = "_w"

    # Aggregate (weighted) counts: groups x subtypes.
    counts = (
        data.groupby([group_labels, subtype_col])[weight_col].sum()
        .unstack(fill_value=0.0)
        .loc[groups]
    )

    # Subtype order
    if subtype_order is not None:
        present = [s for s in subtype_order if s in counts.columns]
        missing = [s for s in counts.columns if s not in present]
        ordered_cols = present + missing
    else:
        ordered_cols = counts.sum(axis=0).sort_values(ascending=False).index.tolist()

    counts = counts[ordered_cols]
    totals = counts.sum(axis=1)
    props = counts.div(totals.values[:, None]).fillna(0.0)

    # Plot setup
    fig, ax = plt.subplots(figsize=figsize, dpi=150)
    n_groups = len(groups)
    x_positions = np.linspace(0, 1, n_groups)
    bar_width = 0.1
    gap = 0.02

    # Colors. Unlisted subtypes draw from the axes color cycle (private
    # matplotlib API).
    subtype_colors = {}
    for s in counts.columns:
        if palette and s in palette:
            subtype_colors[s] = palette[s]
        else:
            subtype_colors[s] = ax._get_lines.get_next_color()

    # Stacked [bottom, top] interval of each subtype within each bar.
    y_positions = {}
    for g in groups:
        y_positions[g] = {}
        y0 = 0.0
        for s in counts.columns:
            h = props.loc[g, s] if normalize else counts.loc[g, s] / totals.loc[g]
            y_positions[g][s] = (y0, y0 + h)
            y0 += h

    # Draw bars
    for gi, g in enumerate(groups):
        x = x_positions[gi]
        ax.add_patch(Rectangle((x - bar_width/2, 0), bar_width, 1.0,
                               fill=False, lw=0.5))
        for s in counts.columns:
            yb, yt = y_positions[g][s]
            ax.add_patch(Rectangle((x - bar_width/2, yb),
                                   bar_width, yt - yb,
                                   facecolor=subtype_colors[s],
                                   edgecolor='none'))

    # Draw ribbons between consecutive groups
    for gi in range(n_groups - 1):
        g0, g1 = groups[gi], groups[gi + 1]
        x0, x1 = x_positions[gi] + bar_width/2 + gap, x_positions[gi+1] - bar_width/2 - gap

        for s in counts.columns:
            c = subtype_colors[s]
            y0b, y0t = y_positions[g0][s]
            y1b, y1t = y_positions[g1][s]
            path = _ribbon_path(x0, x1, y0b, y0t, y1b, y1t, curvature=curvature, steps=60)
            ax.add_patch(PathPatch(path, facecolor=c, alpha=0.6, edgecolor='none'))

    # Cosmetics. Pad by half a bar so the first and last bars are fully
    # visible; the old limits of -gap..1+gap clipped them.
    ax.set_xlim(-bar_width / 2 - gap, 1 + bar_width / 2 + gap)
    ax.set_ylim(0, 1)
    ax.set_xticks(x_positions)
    ax.set_xticklabels(groups)
    ax.set_ylabel("Fraction" if normalize else "Normalized height")
    if title is None:
        title = "Alluvial plot of subtype composition across groups"
    ax.set_title(title)
    ax.grid(False)

    # Percentage labels on sufficiently tall segments
    for gi, g in enumerate(groups):
        x = x_positions[gi]
        for s in counts.columns:
            yb, yt = y_positions[g][s]
            frac = yt - yb
            if frac >= min_frac_to_label:
                ax.text(x, yb + frac/2, f"{frac*100:.0f}%", ha='center', va='center', fontsize=8)

    # Legend
    handles = [Rectangle((0,0),1,1, facecolor=subtype_colors[s], edgecolor='none') for s in counts.columns]
    ax.legend(handles, counts.columns.tolist(), title=subtype_col,
              bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0.)

    fig.tight_layout()
    if outfile_prefix:
        fig.savefig(outfile_prefix, bbox_inches='tight', dpi=200)

    return fig, ax, counts, props


In [None]:
import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt

def plot_DEG_dotplot_multi_niches(
    adata,
    niche_genes: dict,
    gene_group,
    is_contaminant,
    remove_contaminants: bool = True,
    cmap: str = "RdBu_r",
    save_path: str = None,
    vmin: float = -0.5,
    vmax: float = 0.5,
    dot_max: float = 0.6,
    dot_min: float = 0.05,
    groupby: str = "leiden_joint",
    figsize=(9, 3),
):
    """
    Create a Scanpy dotplot of DEGs from an arbitrary number of niches.

    Genes from every niche list are optionally filtered with `is_contaminant`
    and annotated with `gene_group`. Genes annotated "Other" are dropped. Each
    gene is plotted once (its first listing wins). Columns are ordered by
    function, then niche, then gene name.

    Parameters
    ----------
    adata : AnnData
        AnnData object containing expression data.
    niche_genes : dict
        Keys are niche labels (str); values are iterables of gene names.
        Example: {"Niche 1": niche1_genes, "Niche 2": niche2_genes}
    gene_group : callable
        Maps a gene name to a functional category string.
    is_contaminant : callable
        Returns True if a gene should be treated as a contaminant.
    remove_contaminants : bool
        Whether to filter out contaminants.
    cmap : str
        Colormap for the Scanpy dotplot.
    save_path : str, optional
        If given, the figure is saved there (dpi=300).
    vmin, vmax : float
        Color scale limits.
    dot_max, dot_min : float
        Dot size limits.
    groupby : str
        Column in adata.obs to group cells by.
    figsize : tuple
        Figure size.

    Returns
    -------
    pd.DataFrame
        One row per plotted gene (columns gene, niche, function), in plot order.

    Raises
    ------
    ValueError
        If no genes remain after filtering.

    Notes
    -----
    Calls ``sc.set_figure_params``, which changes Scanpy/matplotlib global
    figure settings.
    """
    # 1. Combine and clean gene lists
    records = []
    for niche_label, genes in niche_genes.items():
        if remove_contaminants:
            genes = [g for g in genes if not is_contaminant(g)]
        for g in genes:
            f = gene_group(g)
            if f != "Other":  # skip unclassified
                records.append({"gene": g, "niche": niche_label, "function": f})

    if not records:
        raise ValueError("No valid genes remaining after filtering contaminants or 'Other' category.")

    df = pd.DataFrame(records)
    # A gene listed under several niches would be a duplicate dotplot column.
    # Assuming gene_group is deterministic, this keeps its first niche.
    df = df.drop_duplicates(subset=["gene", "function"], keep="first")

    # 2. Order genes by function, then niche, then gene name
    df = df.sort_values(["function", "niche", "gene"])
    grouped_genes = df["gene"].tolist()

    # 3. Plot with Scanpy dotplot
    sc.set_figure_params(dpi=200, dpi_save=300, facecolor="white")

    dp = sc.pl.dotplot(
        adata,
        var_names=grouped_genes,
        groupby=groupby,
        vmin=vmin, vmax=vmax,
        cmap=cmap,
        dot_max=dot_max,
        dot_min=dot_min,
        show=False,
        return_fig=True,
        figsize=figsize,
    )

    # 4. Render (and optionally save)
    # With return_fig=True, Scanpy returns an un-rendered DotPlot that is only
    # drawn by DotPlot.savefig()/show(). The old bare plt.show() displayed
    # nothing when save_path was None.
    if save_path:
        dp.savefig(save_path, bbox_inches="tight", dpi=300)
        print(f"✅ Saved: {save_path}")
        plt.show()
    else:
        dp.show()
    return df


# How do malignant cells change across niches?

In [None]:
# Malignant cells that fall in one of the malignant niches (materialized copy).
subadata = adata[(adata.obs.refined_celltypes.isin(['Malignant'])) & (adata.obs.leiden_joint.isin(malignant_niches))].copy()

In [None]:
# Colour palettes for the malignant subset (palette_dict_2 is used by the alluvial plots below).
palette_specs = {
            'refined_cellsubtypes': tab_filtered,
            'sample_id': tab_filtered,
        }

palette_dict_2 = build_palettes_from_adata(subadata, palette_specs)

In [None]:
# Mean expression per gene over the malignant subset. X is normalize_total +
# log1p transformed if the "Transform adata" cells ran first. Only genes with
# mean > 0.2 are kept for differential testing.
avg_expr = pd.Series(np.asarray(subadata.X.mean(axis=0)).ravel(), index=subadata.var_names)

hex_subadata = subadata[:,avg_expr[avg_expr>0.2].index].copy()

In [None]:
# Wilcoxon DE test of each niche against the other cells in hex_subadata (Scanpy default reference).
sc.tl.rank_genes_groups(hex_subadata, groupby='leiden_joint', method='wilcoxon')

In [None]:
# One ranked DE table per niche. NOTE: the loop variable `gr` stays bound in
# the notebook namespace after this cell.
dgex = {}
for gr in hex_subadata.obs.leiden_joint.unique():
    dgex[gr] = sc.get.rank_genes_groups_df(hex_subadata, group=gr)

In [None]:
# Print the first 50 ranked gene names of each niche.
for cl in dgex:
    print(cl)
    print(dgex[cl].head(50).names.ravel())

In [None]:
# Per-gene scaled copy of the malignant subset (all genes), used by the dotplots below.
tmp = subadata.copy()

sc.pp.scale(tmp)

In [None]:
# Drop cells labelled 'Other' (only present if 'Other' is one of malignant_niches).
tmp = tmp[tmp.obs.leiden_joint!='Other'].copy()

In [None]:

# Ordered (label, members) table for gene_group(). The first program whose
# member set contains the upper-cased gene symbol wins. The sets are disjoint.
_GENE_PROGRAMS = (
    # ---------- Cluster 16-dominant programs ----------
    ("Cluster16: ECM remodeling / mesenchymal", frozenset({
        "SPARC", "DCN", "LUM", "BGN", "VCAN",
        "VIM", "TAGLN", "CTHRC1", "GPNMB", "TNFAIP2", "CTSB", "CTSZ",
    })),
    ("Cluster16: secretory / ER-stress (UPR)", frozenset({
        "XBP1", "TXNDC5", "SSR4", "NUPR1", "PSAP", "LGALS1", "IGFBP7",
    })),
    ("Cluster16: mitochondrial / redox metabolism", frozenset({
        "SOD2", "CYBA", "FTL", "FTH1",
    })),
    # ---------- Cluster 10-dominant programs ----------
    ("Cluster10: basal / junctional epithelium", frozenset({
        "KRT5", "KRT8", "KRT15", "KRT17", "KRT19", "CLDN1", "CLDN4", "JUP",
    })),
    ("Cluster10: stress / hypoxia / signaling", frozenset({
        "HSPB1", "DDIT4", "NR2F2", "SOX4", "STAT3",
        "CRELD2", "ERO1A", "TFRC", "S100A8", "S100A9",
    })),
    ("Cluster10: proliferation / growth control", frozenset({
        "SGK1", "KLF5", "CCNL2", "ZNF217", "EIF4EBP1",
    })),
    ("Cluster10: cytoskeletal / motility", frozenset({"EZR", "PFN2", "ACTG1", "ARL4C"})),
    ("Cluster10: metabolic adaptation", frozenset({"ENO1", "FASN", "HBS1L", "PTMS"})),
    ("Cluster10: receptor / guidance signaling", frozenset({"FGFR1", "PLXNB2"})),
)


def gene_group(gene: str) -> str:
    """
    Map a gene symbol (case-insensitive) to the curated tumor-cell program
    label used for the cluster 10 / cluster 16 malignant DEG dotplots.

    Returns "Other" for genes that are in none of the curated sets.
    """
    symbol = gene.upper()
    return next(
        (label for label, members in _GENE_PROGRAMS if symbol in members),
        "Other",
    )



def is_contaminant(g):
    """
    Return True if gene symbol `g` should be excluded as a likely contaminant.

    The check is case-insensitive. A gene is flagged if its name starts with
    RPS, RPL or HIST, or if it contains any of the endothelial/stromal marker
    substrings below. This is a substring match, so "COL" flags every gene
    containing "COL".
    """
    symbol = g.upper()
    if symbol.startswith(("RPS", "RPL", "HIST")):
        return True
    # endothelial / stromal contamination markers
    markers = (
        "VWF", "PECAM", "LYVE1", "TEK", "COL",
        "ASPN", "MFAP", "TFPI2", "EDNRA", "NOTCH3", "RAMP3",
    )
    return any(marker in symbol for marker in markers)


In [None]:
niche_genes = {f'Niche {i}': dgex[f'{i}'].head(50).names.ravel() for i in malignant_niches}

In [None]:
df = plot_DEG_dotplot_multi_niches(
    tmp,
    niche_genes=niche_genes,
    gene_group=gene_group,
    is_contaminant=is_contaminant,
    groupby="leiden_joint",
    save_path="../../../SpatialFusion/results/figures_Fig6/dotplot_malignant_deg.svg",
    vmin=-0.25, vmax=0.25, dot_max=0.3, figsize=(11, 3),
)


In [None]:
df = plot_DEG_dotplot_multi_niches(
    tmp[tmp.obs.leiden_joint.isin(['10','16'])],
    niche_genes=niche_genes,
    gene_group=gene_group,
    is_contaminant=is_contaminant,
    groupby="leiden_joint",
    save_path="../../../SpatialFusion/results/figures_Fig6/dotplot_malignant_subset_deg.svg",
    vmin=-0.25, vmax=0.25, dot_max=0.3, figsize=(11, 3),
)


In [None]:
subdf =subadata.obs.copy()
# Composition of malignant cell subtypes in niches 10 and 16, all samples.
fig, ax, counts_df, props_df = alluvial_multi_groups(
    subdf,
    group_col="leiden_joint",
    subtype_col="refined_cellsubtypes",
    palette=palette_dict_2["refined_cellsubtypes"],
    normalize=True,
    title="Malignant cell composition",
    group_order=['10','16'],
    #group_order=[f'{i}' for i in range(18)],
    figsize=(10,3),
    outfile_prefix="../../../SpatialFusion/results/figures_Fig6/malignant_alluvial.svg",
)

In [None]:
# NOTE(review): the subset goes through the AnnData object before taking .obs.
# AnnData subsetting may drop unused categories, which would differ from
# filtering subadata.obs directly. Keep this form unless that is verified.
subdf =subadata[subadata.obs.Histology=='Adenocarcinoma'].obs.copy()

# Same plot restricted to Adenocarcinoma (LUAD) samples.
fig, ax, counts_df, props_df = alluvial_multi_groups(
    subdf,
    group_col="leiden_joint",
    subtype_col="refined_cellsubtypes",
    palette=palette_dict_2["refined_cellsubtypes"],
    normalize=True,
    title="Malignant cell composition",
    group_order=['10','16'],
    #group_order=[f'{i}' for i in range(18)],
    figsize=(10,3),
    outfile_prefix="../../../SpatialFusion/results/figures_Fig6/malignant_LUAD_alluvial.svg",
)

In [None]:
subdf =subadata[subadata.obs.Histology=='Squamous Carcinoma'].obs.copy()

# Same plot restricted to Squamous Carcinoma (LUSC) samples.
fig, ax, counts_df, props_df = alluvial_multi_groups(
    subdf,
    group_col="leiden_joint",
    subtype_col="refined_cellsubtypes",
    palette=palette_dict_2["refined_cellsubtypes"],
    normalize=True,
    title="Malignant cell composition",
    group_order=['10','16'],
    #group_order=[f'{i}' for i in range(18)],
    figsize=(10,3),
    outfile_prefix="../../../SpatialFusion/results/figures_Fig6/malignant_LUSC_alluvial.svg",
)