In [None]:
import scanpy as sc
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pathlib as pl
import math
import tifffile

from matplotlib import patheffects

# Import and preprocess data

## Helper functions

In [None]:
def build_palettes_from_adata(adata, palette_specs):
    """
    Build labeled color palettes for categorical columns in adata.obs.

    Parameters
    ----------
    adata : AnnData
        Must have .obs DataFrame containing categorical columns.
    palette_specs : dict
        Mapping {column_name: palette} where palette can be:
          - a string palette name (e.g. "tab10")
          - a list/tuple of RGB colors (custom)

    Returns
    -------
    dict
        {column_name: {label: color}} mapping. Labels are the string form of
        the non-missing values of the column, in sorted order. Columns that
        are absent from adata.obs are skipped with a printed warning.

    Raises
    ------
    ValueError
        If a palette is neither a string nor a list/tuple.
    """
    custom_palettes = {}

    for col, palette in palette_specs.items():
        if col not in adata.obs.columns:
            print(f"⚠️ Warning: '{col}' not found in adata.obs — skipping.")
            continue

        # Drop missing values BEFORE casting to str: casting first turns
        # NaN/None into the literal labels "nan"/"None", which dropna() then
        # can no longer remove.
        unique_vals = sorted(adata.obs[col].dropna().astype(str).unique())
        n_unique = len(unique_vals)

        # If user passed a name → generate via seaborn
        if isinstance(palette, str):
            pal_colors = sns.color_palette(palette, n_colors=n_unique)
        # If user passed a list → use directly
        elif isinstance(palette, (list, tuple)):
            if palette and n_unique > len(palette):
                # zip() below would silently leave the trailing labels without
                # a color; recycle the palette instead and tell the user.
                print(
                    f"⚠️ Warning: palette for '{col}' has {len(palette)} colors "
                    f"for {n_unique} labels — colors are recycled."
                )
                pal_colors = [palette[i % len(palette)] for i in range(n_unique)]
            else:
                pal_colors = palette[:n_unique]
        else:
            raise ValueError(f"Unsupported palette type for '{col}': {type(palette)}")

        custom_palettes[col] = dict(zip(unique_vals, pal_colors))

    print(f"✅ Built palettes for {len(custom_palettes)} columns.")
    return custom_palettes


def plot_celltype_spatial_single_split_legend(
    df,
    color_by="celltype",
    sample_id=None,
    title=None,
    palette_dict=None,
    palette_name="tab20",
    s=1.5,
    save_svg=True,
    output_prefix="spatial_plot",
    legend_title=None,
):
    """
    Spatial scatterplot of one sample, with the legend drawn on a separate figure.

    The scatter figure is always written as a transparent PNG to
    ``{output_prefix}_{sample_id or 'sample'}_main.png``; the legend figure is
    written as SVG to ``..._legend.svg`` only when ``save_svg`` is true.
    Both figures are closed before returning; nothing is returned.

    Parameters
    ----------
    df : DataFrame
        Needs columns "X_coord", "Y_coord" and ``color_by`` (plus "sample_id"
        when ``sample_id`` is given).
    color_by : str
        Column used for point colors.
    sample_id : optional
        If given, only rows with this "sample_id" are plotted.
    title : optional
        Accepted for call compatibility; not used by this function.
    palette_dict : dict, optional
        {column: {label: color}}; used when it has an entry for ``color_by``,
        otherwise colors come from ``palette_name``.
    palette_name : str
        Seaborn palette name for the fallback colors.
    s : float
        Marker size passed to seaborn.
    save_svg : bool
        Whether to save the legend figure.
    output_prefix : str
        Path prefix for the output files.
    legend_title : str, optional
        Legend title; defaults to ``color_by``.

    Raises
    ------
    ValueError
        If ``sample_id`` is given and no row matches it.
    """
    sns.set_style("white")
    sns.set_context("talk")

    # Restrict to the requested sample, if any
    if sample_id is not None:
        df = df[df["sample_id"] == sample_id].copy()
        if df.empty:
            raise ValueError(f"Sample ID '{sample_id}' not found in DataFrame.")

    # Resolve the label → color mapping
    labels = sorted(df[color_by].dropna().unique())
    if palette_dict is None or color_by not in palette_dict:
        fallback_colors = sns.color_palette(palette_name, n_colors=len(labels))
        color_dict = dict(zip(labels, fallback_colors))
    else:
        color_dict = palette_dict[color_by]

    file_tag = sample_id or 'sample'

    # Scatter figure: rasterized points, no axes decoration, image-style y axis
    fig, ax = plt.subplots(figsize=(6, 5), dpi=300)
    sns.scatterplot(
        data=df,
        x="X_coord",
        y="Y_coord",
        hue=color_by,
        palette=color_dict,
        s=s,
        alpha=0.9,
        linewidth=0,
        rasterized=True,
        ax=ax,
        legend=False,
    )
    ax.invert_yaxis()
    ax.set_aspect("equal", adjustable="box")
    for side in ("top", "right", "left", "bottom"):
        ax.spines[side].set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("")
    ax.set_ylabel("")
    plt.tight_layout()

    fname_main = f"{output_prefix}_{file_tag}_main.png"
    fig.savefig(fname_main, dpi=300, bbox_inches="tight", transparent=True, format="png")
    print(f"Saved main figure: {fname_main}")

    # Legend figure: one round marker per label, height scales with label count
    fig_leg, ax_leg = plt.subplots(figsize=(3, 0.5 * len(labels)), dpi=300)
    handles = []
    for label in labels:
        handles.append(
            plt.Line2D(
                [0], [0],
                marker='o',
                color='none',
                label=label,
                markerfacecolor=color_dict[label],
                markersize=8,
            )
        )
    ax_leg.legend(
        handles=handles,
        loc="center left",
        frameon=False,
        title=legend_title or color_by,
        title_fontsize=14,
        fontsize=14,
    )
    ax_leg.axis("off")
    plt.tight_layout()

    if save_svg:
        fname_leg = f"{output_prefix}_{file_tag}_legend.svg"
        fig_leg.savefig(fname_leg, dpi=300, bbox_inches="tight", transparent=True, format="svg")
        print(f"Saved legend: {fname_leg}")

    plt.close(fig)
    plt.close(fig_leg)


## Analysis

This is the data preprocessed in the Preprocess-OVCA notebook.

In [None]:
rawdata = sc.read_h5ad('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/adata.h5ad')

These region annotations are derived from the annotated image provided on the 10X website. The positions of the annotations were exported using QuPath, and every cell that fell within an annotated region was assigned that region's label. Cells that did not fall within any annotated region were labeled as unassigned.

In [None]:
# Per-cell region labels; the CSV index is used to look cells up by adata obs name
region_annot = pd.read_csv('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/region_annotations.csv',index_col=0)

# .loc reorders the annotations to the AnnData cell order (KeyError if a cell is missing)
rawdata.obs['path_region'] = region_annot.loc[rawdata.obs_names].values.ravel()

# Attach the pixel coordinates as obs columns. Any X_coord/Y_coord left from a
# previous run of this cell are dropped first, so re-running does not append
# duplicate columns.
coords = pd.DataFrame(rawdata.obsm['spatial_px'], index=rawdata.obs_names, columns=['X_coord','Y_coord'])
rawdata.obs = pd.concat([rawdata.obs.drop(columns=coords.columns, errors='ignore'), coords],axis=1)

# .assign returns a new DataFrame, so sample_id is not written into a slice of
# rawdata.obs (which triggered pandas' SettingWithCopyWarning).
region_df = rawdata.obs[['cell_labels', 'minor_celltype', 'major_celltype', 'cell_id',
       'path_region', 'X_coord','Y_coord']].assign(sample_id='TENXOv5k')

In [None]:
# Default seaborn palette without the colors at positions 4 and 6
skipped_default = (4, 6)
tab_filtered = [
    color
    for position, color in enumerate(sns.color_palette())
    if position not in skipped_default
]

# Longer palette: all of 'tab20' followed by the first 11 colors of 'tab20c',
# without the colors at positions 8, 9, 12 and 13 of the combined list
skipped_tab20 = (8, 9, 12, 13)
tab20_filtered = [
    color
    for position, color in enumerate(
        sns.color_palette('tab20') + sns.color_palette('tab20c')[:11]
    )
    if position not in skipped_tab20
]

In [None]:
# Which color list to use for each adata.obs column; build_palettes_from_adata
# pairs the colors with the column's sorted labels.
palette_specs = {
            "path_region": tab_filtered,
            "major_celltype": tab_filtered,
            "minor_celltype": tab20_filtered,
        }

# Result: {column: {label: color}}, passed to the plotting calls below
palette_dict_1 = build_palettes_from_adata(rawdata, palette_specs)

In [None]:
# Replace the automatically built major_celltype palette with a fixed
# label → color assignment (indices refer to seaborn's default palette).
cpal = sns.color_palette()
palette_dict_1['major_celltype'] = {'Malignant': cpal[0], 'Epithelial': cpal[1], 'Myeloid': cpal[2],
                  'Lymphocytes': cpal[3], 'Stromal': cpal[5], 'Endothelial': cpal[9],
                  'Pericytes': cpal[8], 'Unassigned': cpal[7]}

In [None]:
plot_celltype_spatial_single_split_legend(
    region_df,
    color_by="major_celltype",
    sample_id=None,
    title='OVCA\nCell type',
    palette_dict=palette_dict_1,
    s=1.5,
    save_svg=True,
    output_prefix="../../../SpatialFusion/results/figures_Fig2/OVCA_celltype",
    legend_title='Cell Type'
)

In [None]:
plot_celltype_spatial_single_split_legend(
    region_df,
    color_by="minor_celltype",
    sample_id=None,
    title='OVCA\nCell subtype',
   palette_dict=palette_dict_1,
    s=1.5,
    save_svg=True,
    output_prefix="../../../SpatialFusion/results/figures_Fig2/OVCA_cellsubtype",
    legend_title='Cell Subtype'
)

In [None]:
plot_celltype_spatial_single_split_legend(
    region_df,
    color_by="path_region",
    sample_id=None,
    title='OVCA\nPathologist-annotated region',
    palette_dict=palette_dict_1,
    s=1.5,
    save_svg=True,
    output_prefix="../../../SpatialFusion/results/figures_Fig2/OVCA_pathregion",
    legend_title='Path.-annotated region'
)

## Load estimated pathway activity

In [None]:
pathway_matrix = pd.read_parquet('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/pathway_activation.parquet')

# Embed sample

In [None]:
from spatialfusion.embed.embed import AEInputs, run_full_embedding

We compare several variants of the model here: with and without cross-modality losses in the multimodal AE (mAE), and using only the H&E-aligned mAE input, only the RNA-aligned mAE input, or the joint input (either concatenated or averaged).

In [None]:
basepath = pl.Path('../../../Broad_SpatialFoundation/test_data/')
sample_name = '10X_Xenium_Ovarian_5k'
output_dir = basepath / sample_name

In [None]:
uni_df = pd.read_csv(pl.Path(output_dir) / 'embeddings' / 'UNI.csv', index_col=0)
scgpt_df = pd.read_csv(pl.Path(output_dir) / 'embeddings' / 'scGPT.csv', index_col=0)

In [None]:
adata = sc.read_h5ad(basepath / sample_name / 'adata.h5ad')
adata.obs = pd.concat([adata.obs, pd.DataFrame(adata.obsm['spatial_px'], index=adata.obs_names, columns=['X_coord','Y_coord'])],axis=1)
adata.obs["sample_id"] = sample_name

In [None]:
ae_inputs_by_sample = {
    sample_name: AEInputs(adata=adata, z_uni=uni_df, z_scgpt=scgpt_df),
}

## Embedding averaged

In [None]:
# this uses the average version
embeddings_df = run_full_embedding(
    ae_inputs_by_sample=ae_inputs_by_sample,
    ae_model_path='../../../Broad_SpatialFoundation/checkpoint_dir_ae/paired_model_6c22d731.pt',
    gcn_model_path='../../../Broad_SpatialFoundation/checkpoint_dir_gcn/gcn_20250828-123835_e926ee8d/model.pt',
    device="cuda:1",
    combine_mode="average",
    spatial_key='spatial_px',
    celltype_key='major_celltype',
    save_ae_dir=None,  # optional
)

In [None]:
out_path = "../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/gcn_embeddings_new.parquet"
embeddings_df.to_parquet(out_path)

## Embedding concat

In [None]:
# Concatenation variant: combine_mode="concat", with a different GCN checkpoint
# than the averaged run above.
embeddings_df = run_full_embedding(
    ae_inputs_by_sample=ae_inputs_by_sample,
    ae_model_path='../../../Broad_SpatialFoundation/checkpoint_dir_ae/paired_model_6c22d731.pt',
    gcn_model_path='../../../Broad_SpatialFoundation/checkpoint_dir_gcn/gcn_20251001-101456_3d65e602/model.pt',
    device="cuda:1",
    combine_mode="concat",
    spatial_key='spatial_px',
    celltype_key='major_celltype',
    save_ae_dir=None,  # optional
)

In [None]:
out_path = "../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/gcn_embeddings_concat_new.parquet"
embeddings_df.to_parquet(out_path)

## Embedding with just H&E

In [None]:
# H&E-only variant: combine_mode="z1" (per the section heading, z1 is the
# H&E side — TODO confirm against run_full_embedding), with its own GCN checkpoint.
embeddings_df = run_full_embedding(
    ae_inputs_by_sample=ae_inputs_by_sample,
    ae_model_path='../../../Broad_SpatialFoundation/checkpoint_dir_ae/paired_model_6c22d731.pt',
    gcn_model_path='../../../Broad_SpatialFoundation/checkpoint_dir_gcn/gcn_20251001-102239_fa6fe395/model.pt',
    device="cuda:1",
    combine_mode="z1",
    spatial_key='spatial_px',
    celltype_key='major_celltype',
    save_ae_dir=None,  # optional
)

In [None]:
out_dir = pl.Path('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/')
out_path = out_dir / f"gcn_embeddings_he_only_new.parquet"
embeddings_df.to_parquet(out_path)

## Embedding with RNA only

In [None]:
# RNA-only variant: combine_mode="z2" (per the section heading, z2 is the
# RNA side — TODO confirm against run_full_embedding), with its own GCN checkpoint.
embeddings_df = run_full_embedding(
    ae_inputs_by_sample=ae_inputs_by_sample,
    ae_model_path='../../../Broad_SpatialFoundation/checkpoint_dir_ae/paired_model_6c22d731.pt',
    gcn_model_path='../../../Broad_SpatialFoundation/checkpoint_dir_gcn/gcn_20251001-102510_058c13d3/model.pt',
    device="cuda:1",
    combine_mode="z2",
    spatial_key='spatial_px',
    celltype_key='major_celltype',
    save_ae_dir=None,  # optional
)

In [None]:
out_dir = pl.Path('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/')
out_path = out_dir / f"gcn_embeddings_rna_only_new.parquet"
embeddings_df.to_parquet(out_path)

## Embedding with only recon average

In [None]:
# Averaged variant (combine_mode="average") with the other AE checkpoint,
# paired_model_c036a288.pt — the "only recon" model per the section heading.
embeddings_df = run_full_embedding(
    ae_inputs_by_sample=ae_inputs_by_sample,
    ae_model_path='../../../Broad_SpatialFoundation/checkpoint_dir_ae/paired_model_c036a288.pt',
    gcn_model_path='../../../Broad_SpatialFoundation/checkpoint_dir_gcn/gcn_20250828-123835_e926ee8d/model.pt',
    device="cuda:1",
    combine_mode="average",
    spatial_key='spatial_px',
    celltype_key='major_celltype',
    save_ae_dir=None,  # optional
)

In [None]:
out_dir = pl.Path('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/')
out_path = out_dir / f"gcn_embeddings_onlyrecon_new.parquet"
embeddings_df.to_parquet(out_path)

## Embedding with only recon concat

In [None]:
# Concatenation variant (combine_mode="concat") with the "only recon" AE
# checkpoint paired_model_c036a288.pt and its own GCN checkpoint.
embeddings_df = run_full_embedding(
    ae_inputs_by_sample=ae_inputs_by_sample,
    ae_model_path='../../../Broad_SpatialFoundation/checkpoint_dir_ae/paired_model_c036a288.pt',
    gcn_model_path='../../../Broad_SpatialFoundation/checkpoint_dir_gcn/gcn_20251001-095106_7db86b8f/model.pt',
    device="cuda:1",
    combine_mode="concat",
    spatial_key='spatial_px',
    celltype_key='major_celltype',
    save_ae_dir=None,  # optional
)

In [None]:
out_dir = pl.Path('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/')
out_path = out_dir / f"gcn_embeddings_onlyrecon_concat_new.parquet"
embeddings_df.to_parquet(out_path)

## Embedding with only recon, H&E

In [None]:
# H&E-only variant (combine_mode="z1") with the "only recon" AE checkpoint
# paired_model_c036a288.pt and its own GCN checkpoint.
embeddings_df = run_full_embedding(
    ae_inputs_by_sample=ae_inputs_by_sample,
    ae_model_path='../../../Broad_SpatialFoundation/checkpoint_dir_ae/paired_model_c036a288.pt',
    gcn_model_path='../../../Broad_SpatialFoundation/checkpoint_dir_gcn/gcn_20251001-101143_444107d5/model.pt',
    device="cuda:1",
    combine_mode="z1",
    spatial_key='spatial_px',
    celltype_key='major_celltype',
    save_ae_dir=None,  # optional
)

In [None]:
out_dir = pl.Path('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/')
out_path = out_dir / f"gcn_embeddings_onlyrecon_he_new.parquet"
embeddings_df.to_parquet(out_path)

## Embedding with only recon, RNA

In [None]:
# RNA-only variant (combine_mode="z2") with the "only recon" AE checkpoint
# paired_model_c036a288.pt and its own GCN checkpoint.
embeddings_df = run_full_embedding(
    ae_inputs_by_sample=ae_inputs_by_sample,
    ae_model_path='../../../Broad_SpatialFoundation/checkpoint_dir_ae/paired_model_c036a288.pt',
    gcn_model_path='../../../Broad_SpatialFoundation/checkpoint_dir_gcn/gcn_20251001-100041_c9d24ac4/model.pt',
    device="cuda:1",
    combine_mode="z2",
    spatial_key='spatial_px',
    celltype_key='major_celltype',
    save_ae_dir=None,  # optional
)

In [None]:
out_dir = pl.Path('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/')
out_path = out_dir / f"gcn_embeddings_onlyrecon_rna_new.parquet"
embeddings_df.to_parquet(out_path)

# SDMBench

## Helper functions

In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm

import scanpy as sc

from scipy.spatial import distance
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import *


In [None]:
def compute_PAS_fast(clusterlabel, location, k=10):
    """
    Fraction of points whose spatial neighbourhood mostly disagrees with them.

    For every point the k+1 nearest points in ``location`` are queried and the
    first returned neighbour is discarded (it is the point itself unless
    coordinates are duplicated). A point counts when more than k/2 of the
    remaining k neighbours carry a different cluster label.

    Parameters
    ----------
    clusterlabel : array-like
        One cluster label per point.
    location : array-like
        Coordinates, one row per point, in the same order as ``clusterlabel``.
    k : int
        Number of neighbours compared against each point.

    Returns
    -------
    float
        Number of counted points divided by the number of points.
    """
    labels = np.array(clusterlabel)
    coords = np.array(location)

    # k+1 neighbours because the query set equals the fitted set
    knn = NearestNeighbors(n_neighbors=k+1, algorithm='auto').fit(coords)
    _, nn_idx = knn.kneighbors(coords)
    nn_idx = nn_idx[:, 1:]  # shape: (n_samples, k)

    # Row i holds, for each neighbour of point i, whether its label differs
    n_differing = (labels[nn_idx] != labels[:, None]).sum(axis=1)
    is_outlier = n_differing > (k / 2)

    return np.sum(is_outlier) / len(labels)


def compute_CHAOS_fast(clusterlabel, location):
    """
    Mean within-cluster nearest-neighbour distance on standardized coordinates.

    Coordinates are passed through ``StandardScaler`` once, over all points.
    Within every cluster of more than two points, each point's distance to its
    nearest same-cluster point is summed; the total is divided by the number
    of points in the clusters that were used.

    Parameters
    ----------
    clusterlabel : array-like
        One cluster label per point.
    location : array-like
        Coordinates, one row per point, in the same order as ``clusterlabel``.

    Returns
    -------
    float
        The mean distance, or ``np.nan`` when no cluster has more than two points.
    """
    labels = np.array(clusterlabel)
    coords_std = StandardScaler().fit_transform(np.array(location))

    nn_dist_sum = 0
    n_counted = 0

    for cluster in tqdm(np.unique(labels), desc="Computing CHAOS"):
        cluster_pts = coords_std[labels == cluster]
        n_pts = cluster_pts.shape[0]

        # Clusters of one or two points are left out of both sum and count
        if n_pts > 2:
            # Two neighbours: column 0 is the point itself, column 1 its nearest neighbour
            knn = NearestNeighbors(n_neighbors=2, algorithm='auto').fit(cluster_pts)
            nn_dist, _ = knn.kneighbors(cluster_pts)
            nn_dist_sum += np.sum(nn_dist[:, 1])
            n_counted += n_pts

    if n_counted > 0:
        return nn_dist_sum / n_counted
    return np.nan


def compute_ASW_fast(adata, pred_key, spatial_key='spatial'):
    """Euclidean silhouette score of the labels in ``adata.obs[pred_key]`` on the coordinates in ``adata.obsm[spatial_key]``."""
    return silhouette_score(
        X=adata.obsm[spatial_key],
        labels=adata.obs[pred_key],
        metric='euclidean',
    )

def compute_ARI(adata, gt_key, pred_key):
    """Adjusted Rand index between ``adata.obs[gt_key]`` (first argument) and ``adata.obs[pred_key]`` (second)."""
    truth = adata.obs[gt_key]
    predicted = adata.obs[pred_key]
    return adjusted_rand_score(truth, predicted)

def compute_NMI(adata, gt_key, pred_key):
    """Normalized mutual information between ``adata.obs[gt_key]`` (first argument) and ``adata.obs[pred_key]`` (second)."""
    truth = adata.obs[gt_key]
    predicted = adata.obs[pred_key]
    return normalized_mutual_info_score(truth, predicted)

def compute_HOM(adata, gt_key, pred_key):
    """Homogeneity score, with ``adata.obs[gt_key]`` as the first argument and ``adata.obs[pred_key]`` as the second (the score is not symmetric)."""
    truth = adata.obs[gt_key]
    predicted = adata.obs[pred_key]
    return homogeneity_score(truth, predicted)

def compute_COM(adata, gt_key, pred_key):
    """Completeness score, with ``adata.obs[gt_key]`` as the first argument and ``adata.obs[pred_key]`` as the second (the score is not symmetric)."""
    truth = adata.obs[gt_key]
    predicted = adata.obs[pred_key]
    return completeness_score(truth, predicted)

## Analysis

In [None]:
adata = rawdata.copy()

In [None]:
# Re-read a previously saved obs table (pre-computed clustering columns)
# instead of recomputing it.
adata_obs= pd.read_csv('../../../Broad_SpatialFoundation/notebooks/benchmark_ovarian_adata_obs.csv',index_col=0)
# read_csv may parse the index as numbers; cast it to str
adata_obs.index = adata_obs.index.astype(str)

# NOTE(review): this replaces adata.obs wholesale and does not reorder the CSV
# rows to adata.obs_names — assumes the CSV is in the same cell order; confirm.
adata.obs = adata_obs

Now we read all the embeddings and add them to the AnnData

In [None]:
# --- GCN embeddings written by the "Embed sample" section; indexed by their 'cell_id' column ---
emb_df = pd.read_parquet('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/gcn_embeddings_new.parquet').set_index('cell_id')

emb_concat_df = pd.read_parquet('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/gcn_embeddings_concat_new.parquet').set_index('cell_id')

emb_he_df = pd.read_parquet('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/gcn_embeddings_he_only_new.parquet').set_index('cell_id')

emb_rna_df = pd.read_parquet('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/gcn_embeddings_rna_only_new.parquet').set_index('cell_id')

emb_onlyrecon_df = pd.read_parquet('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/gcn_embeddings_onlyrecon_new.parquet').set_index('cell_id')

emb_onlyrecon_concat_df = pd.read_parquet('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/gcn_embeddings_onlyrecon_concat_new.parquet').set_index('cell_id')

emb_onlyrecon_he_df = pd.read_parquet('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/gcn_embeddings_onlyrecon_he_new.parquet').set_index('cell_id')

emb_onlyrecon_rna_df = pd.read_parquet('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/gcn_embeddings_onlyrecon_rna_new.parquet').set_index('cell_id')

# --- Embeddings from other methods, read from pre-computed files ---
# The parquet files are used with the index they were saved with.
banksy_embeddings = pd.read_parquet('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/embeddings/banksy_08.parquet')

nichecompass_embeddings = pd.read_parquet('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/embeddings/nichecompass.parquet')

# The two CSVs get their index from the 'cell_id' column / the first column, respectively
nicheformer_embeddings = pd.read_csv('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/embeddings/nicheformer.csv').set_index('cell_id')

scgptspatial_embeddings = pd.read_csv('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/embeddings/scGPTspatial.csv',index_col=0)

omiclip_text_embeddings = pd.read_parquet('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/OmiCLIP_text_emb.parquet')

omiclip_image_embeddings = pd.read_parquet('../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/OmiCLIP_image_emb.parquet')

In [None]:
# GCN-derived tables: keep the ten columns named '0'..'9', rows reordered to adata.obs_names
gcn_dims = [str(dim) for dim in range(10)]
gcn_frames = {
    'gcn': emb_df,
    'gcn_concat': emb_concat_df,
    'gcn_he': emb_he_df,
    'gcn_rna': emb_rna_df,
    'gcn_onlyrecon': emb_onlyrecon_df,
    'gcn_onlyrecon_concat': emb_onlyrecon_concat_df,
    'gcn_onlyrecon_he': emb_onlyrecon_he_df,
    'gcn_onlyrecon_rna': emb_onlyrecon_rna_df,
}
for obsm_key, frame in gcn_frames.items():
    adata.obsm[obsm_key] = frame.loc[adata.obs_names, gcn_dims]

# Tables from the other methods: keep every column, rows reordered to adata.obs_names
other_frames = {
    'banksy': banksy_embeddings,
    'nichecompass': nichecompass_embeddings,
    'nicheformer': nicheformer_embeddings,
    'scgptspatial': scgptspatial_embeddings,
    'OmiCLIP_text': omiclip_text_embeddings,
    'OmiCLIP_image': omiclip_image_embeddings,
}
for obsm_key, frame in other_frames.items():
    adata.obsm[obsm_key] = frame.loc[adata.obs_names]

### Run clustering

The Leiden resolutions are chosen so that each method yields a number of clusters as close as possible to the target (here, 13). The smallest clusters are then merged so that all methods end up with the same number of clusters for comparison.

In [None]:
# Shortcut: re-read the clustering columns saved by a previous run of this
# section. The file is written by the `adata.obs.to_csv(...)` cell at the end
# of "Run clustering", so on a fresh run it does not exist until that cell has run.
adata_obs= pd.read_csv('benchmark_ovarian_adata_obs_new.csv',index_col=0)
# read_csv may parse the index as numbers; cast it to str
adata_obs.index = adata_obs.index.astype(str)

# NOTE(review): replaces adata.obs wholesale without reordering rows to
# adata.obs_names — assumes the CSV is in the same cell order; confirm.
adata.obs = adata_obs

In [None]:
sc.pp.neighbors(adata, use_rep = 'gcn')

In [None]:
sc.tl.leiden(adata, resolution=0.1, flavor="igraph", n_iterations=2)

In [None]:
# Get cluster sizes, largest first
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Make mapping: old → new (ranked by size, so '0' is the largest cluster)
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply mapping
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort categories by their new numeric label. reorder_categories returns a new
# Series, so the result has to be assigned back; calling it without assignment
# left the categories in their previous order.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Cluster sizes under the new labels
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
adata.obs['leiden_gcn'] = adata.obs.leiden.replace({'13': '12', '14': '12', '15': '12', '16': '12', '17': '12',})

In [None]:
sc.pp.neighbors(adata, use_rep = 'gcn_concat')

In [None]:
sc.tl.leiden(adata, resolution=0.2, flavor="igraph", n_iterations=2)

In [None]:
# Get cluster sizes, largest first
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Make mapping: old → new (ranked by size, so '0' is the largest cluster)
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply mapping
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort categories by their new numeric label. reorder_categories returns a new
# Series, so the result has to be assigned back; calling it without assignment
# left the categories in their previous order.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Cluster sizes under the new labels
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
adata.obs['leiden_gcn_concat'] = adata.obs.leiden.replace({'13': '12', '14': '12', '15': '12', '16': '12', '17': '12',})

In [None]:
sc.pp.neighbors(adata, use_rep = 'gcn_he')

In [None]:
sc.tl.leiden(adata, resolution=0.1, flavor="igraph", n_iterations=2)

In [None]:
# Get cluster sizes, largest first
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Make mapping: old → new (ranked by size, so '0' is the largest cluster)
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply mapping
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort categories by their new numeric label. reorder_categories returns a new
# Series, so the result has to be assigned back; calling it without assignment
# left the categories in their previous order.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Cluster sizes under the new labels
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
adata.obs['leiden_gcn_he'] = adata.obs.leiden.replace({'13': '12', '14': '12', '15': '12', '16': '12',})

In [None]:
sc.pp.neighbors(adata, use_rep = 'gcn_rna')

In [None]:
sc.tl.leiden(adata, resolution=0.2, flavor="igraph", n_iterations=2)

In [None]:
# Get cluster sizes, largest first
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Make mapping: old → new (ranked by size, so '0' is the largest cluster)
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply mapping
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort categories by their new numeric label. reorder_categories returns a new
# Series, so the result has to be assigned back; calling it without assignment
# left the categories in their previous order.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Cluster sizes under the new labels
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
adata.obs['leiden_gcn_rna'] = adata.obs.leiden.replace({'13': '12', '14': '12', '15': '12', '16': '12', '17': '12',
                                                          '18': '12', '19': '12', '20': '12', '21': '12', '22': '12',})

In [None]:
sc.pp.neighbors(adata, use_rep = 'gcn_onlyrecon')

In [None]:
sc.tl.leiden(adata, resolution=0.1, flavor="igraph", n_iterations=2)

In [None]:
# Get cluster sizes, largest first
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Make mapping: old → new (ranked by size, so '0' is the largest cluster)
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply mapping
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort categories by their new numeric label. reorder_categories returns a new
# Series, so the result has to be assigned back; calling it without assignment
# left the categories in their previous order.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Cluster sizes under the new labels
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
adata.obs['leiden_gcn_onlyrecon'] = adata.obs.leiden.replace({'13': '12', '14': '12', '15': '12', '16': '12',
                                                              '17': '12', '18': '12', '19': '12',})

In [None]:
sc.pp.neighbors(adata, use_rep = 'gcn_onlyrecon_concat')

In [None]:
sc.tl.leiden(adata, resolution=0.1, flavor="igraph", n_iterations=2)

In [None]:
# Get cluster sizes, largest first
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Make mapping: old → new (ranked by size, so '0' is the largest cluster)
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply mapping
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort categories by their new numeric label. reorder_categories returns a new
# Series, so the result has to be assigned back; calling it without assignment
# left the categories in their previous order.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Cluster sizes under the new labels
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
adata.obs['leiden_gcn_onlyrecon_concat'] = adata.obs.leiden.replace({'13': '12', '14': '12', '15': '12', '16': '12',
                                                                     '17': '12', '18': '12', '19': '12', '20': '12',})

In [None]:
sc.pp.neighbors(adata, use_rep = 'gcn_onlyrecon_he')

In [None]:
sc.tl.leiden(adata, resolution=0.1, flavor="igraph", n_iterations=2)

In [None]:
# Get cluster sizes, largest first
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Make mapping: old → new (ranked by size, so '0' is the largest cluster)
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply mapping
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort categories by their new numeric label. reorder_categories returns a new
# Series, so the result has to be assigned back; calling it without assignment
# left the categories in their previous order.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Cluster sizes under the new labels
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
adata.obs['leiden_gcn_onlyrecon_he'] = adata.obs.leiden.replace({'13': '12', '14': '12', '15': '12', '16': '12',})

In [None]:
sc.pp.neighbors(adata, use_rep = 'gcn_onlyrecon_rna')

In [None]:
sc.tl.leiden(adata, resolution=0.2, flavor="igraph", n_iterations=2)

In [None]:
# Get cluster sizes, largest first
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Make mapping: old → new (ranked by size, so '0' is the largest cluster)
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply mapping
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort categories by their new numeric label. reorder_categories returns a new
# Series, so the result has to be assigned back; calling it without assignment
# left the categories in their previous order.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Cluster sizes under the new labels
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
adata.obs['leiden_gcn_onlyrecon_rna'] = adata.obs.leiden.replace({'13': '12', '14': '12', '15': '12', '16': '12',
                                                                 '17': '12', '18': '12', '19': '12', '20': '12',})

In [None]:
sc.pp.neighbors(adata, use_rep = 'banksy')

In [None]:
sc.tl.leiden(adata, resolution=0.5, flavor="igraph", n_iterations=2)

In [None]:
# Get cluster sizes, largest first
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Make mapping: old → new (ranked by size, so '0' is the largest cluster)
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply mapping
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort categories by their new numeric label. reorder_categories returns a new
# Series, so the result has to be assigned back; calling it without assignment
# left the categories in their previous order.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Cluster sizes under the new labels
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
adata.obs['leiden_banksy_08'] = adata.obs.leiden.replace({'13': '12', '14': '12', '15': '12', '16': '12', '17': '12',})

In [None]:
## RUN ONLY ONCE!
# (normalize_total / log1p near the end modify adata.X in place, so a second
# run would normalize and log-transform already-transformed values.)
#
# Pattern for every method: build the neighbour graph on its embedding, run
# Leiden, then copy adata.obs['leiden'] into a method-specific column BEFORE
# the next Leiden run overwrites it.
sc.pp.neighbors(adata, use_rep='OmiCLIP_text')

sc.tl.leiden(adata, resolution=0.6)

adata.obs.leiden.value_counts()

adata.obs['leiden_OmiCLIP_text'] = adata.obs.leiden

sc.pp.neighbors(adata, use_rep='nichecompass')

sc.tl.leiden(adata, resolution=0.2)

adata.obs.leiden.value_counts()

# Store the nichecompass labels right away. This assignment used to sit after
# the OmiCLIP_image Leiden run, so 'leiden_nichecompass' actually held the
# OmiCLIP_image clustering.
adata.obs['leiden_nichecompass'] = adata.obs.leiden

sc.pp.neighbors(adata, use_rep='OmiCLIP_image')

sc.tl.leiden(adata, resolution=0.3)

# Clusters '13' and '14' are merged into '12'
adata.obs['leiden_OmiCLIP_image'] = adata.obs.leiden.replace({'13': '12', '14': '12',})

sc.pp.neighbors(adata, use_rep='nicheformer')

sc.tl.leiden(adata, resolution=0.5)

adata.obs.leiden.value_counts()

adata.obs['leiden_nicheformer'] = adata.obs.leiden

sc.pp.neighbors(adata, use_rep='scgptspatial')

sc.tl.leiden(adata, resolution=0.3)

adata.obs.leiden.value_counts()

adata.obs['leiden_scgptspatial'] = adata.obs.leiden

# Expression-only baseline: normalize, log-transform, PCA, then cluster on the PCs
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

sc.tl.pca(adata)

sc.pp.neighbors(adata, use_rep='X_pca')

sc.tl.leiden(adata, resolution=1)

adata.obs.leiden.value_counts()

adata.obs['leiden_scanpy'] = adata.obs.leiden

In [None]:
adata.obs.to_csv('benchmark_ovarian_adata_obs_new.csv')

## Compute metrics

In [None]:
def compute_all_metrics(adata, clustering_keys, ground_truth_key='path_region', spatial_key='spatial_px'):
    """
    Compute the benchmark metrics for several clusterings of the same AnnData.

    Parameters
    ----------
    adata : AnnData
        Holds the label columns in ``.obs`` and the coordinates in ``.obsm``.
    clustering_keys : dict
        {method_name: obs column holding that method's cluster labels}.
    ground_truth_key : str
        obs column with the reference labels used by ARI, NMI, HOM and COM.
    spatial_key : str
        obsm key with the coordinates used by PAS and CHAOS.

    Returns
    -------
    pandas.DataFrame
        One row per metric ('ARI', 'NMI', 'HOM', 'COM', 'PAS', 'CHAOS'),
        one column per method.
    """
    results = {}

    for method_name, cluster_key in clustering_keys.items():
        # The compute_* helpers take (adata, gt_key, pred_key): ground truth
        # first. ARI and NMI are symmetric, but homogeneity and completeness
        # are not — passing the keys the other way round swaps the reported
        # HOM and COM values.
        metrics = {
            'ARI': compute_ARI(adata, ground_truth_key, cluster_key),
            'NMI': compute_NMI(adata, ground_truth_key, cluster_key),
            'HOM': compute_HOM(adata, ground_truth_key, cluster_key),
            'COM': compute_COM(adata, ground_truth_key, cluster_key),
            'PAS': compute_PAS_fast(adata.obs[cluster_key], adata.obsm[spatial_key]),
            'CHAOS': compute_CHAOS_fast(adata.obs[cluster_key], adata.obsm[spatial_key]),
        }
        results[method_name] = metrics

    return pd.DataFrame(results)

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from matplotlib.colors import LinearSegmentedColormap

def format_number(value):
    """
    Format a number for display.

    Returns "" for missing values, scientific notation without mantissa
    decimals for nonzero values with |value| < 0.01, and two fixed decimals
    otherwise (including exactly 0).
    """
    if pd.isna(value):
        return ""
    is_tiny = value != 0 and abs(value) < 0.01
    # '.0e' keeps no decimals in the mantissa, e.g. 0.00034 -> '3e-04'
    return f"{value:.0e}" if is_tiny else f"{value:.2f}"

def plot_benchmark_heatmap(
    results_df,
    title="Spatial clustering benchmark",
    savefig=None,
    metric_order=None,
):
    """
    Nature Genetics–style benchmarking heatmap (metrics as rows, methods as columns).

    Each metric row is min-max scaled across methods so that the best method
    gets the highest score (metrics listed in ``lower_better`` are negated
    first). Cell colour encodes that scaled score after a gamma correction;
    the number printed inside each cell is the raw, unscaled metric value.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Metric names in the index, method names in the columns.
    title : str
        Figure title.
    savefig : str or path-like, optional
        Output file. The image format is inferred by matplotlib from the file
        extension. Nothing is written when None/empty.
    metric_order : list of str, optional
        Row order (top to bottom). Defaults to the order of ``results_df.index``.

    Raises
    ------
    ValueError
        If ``results_df`` has no rows or no columns.
    KeyError
        If ``metric_order`` names a metric that is not in ``results_df.index``.
    """
    if results_df.empty:
        raise ValueError("results_df is empty: nothing to plot.")

    lower_better = {'PAS', 'CHAOS'}

    metric_order = list(results_df.index) if metric_order is None else list(metric_order)
    method_order = results_df.columns.tolist()

    # --- Normalize scores: per metric, 0 = worst method, ~1 = best method ---
    df_norm = results_df.astype(float)
    for metric in df_norm.index:
        vals = df_norm.loc[metric]
        if metric in lower_better:
            vals = -vals
        # The small epsilon avoids 0/0 when every method has the same value.
        df_norm.loc[metric] = (vals - vals.min()) / (vals.max() - vals.min() + 1e-9)

    # --- Heatmap matrix in display order, rows labelled with a direction arrow ---
    df_matrix = df_norm.loc[metric_order, method_order].copy()
    df_matrix.index = [
        f"{m} {'↓' if m in lower_better else '↑'}" for m in metric_order
    ]

    # --- Aesthetics ---
    sns.set_theme(style="white", context="talk")

    fig, ax = plt.subplots(
        figsize=(1.3 * len(method_order), 0.8 * len(metric_order)), dpi=300
    )

    # Single colormap shared by the heatmap and the text-colour computation.
    cmap = LinearSegmentedColormap.from_list(
        "vlag_red",
        ["#fee8ef",  # very light pink
         "#f4a3a8",  # pastel red
         "#d95858",  # mid red
         "#b40426"]  # vlag red (vivid crimson)
    )

    # Gamma correction stretches contrast near the top. It only affects the
    # cell colour, never the numbers printed in the cells.
    gamma = 3
    df_matrix_contrast = df_matrix ** gamma

    # vmin/vmax are pinned to the normalized range so that cmap(value) below is
    # exactly the colour seaborn paints for that cell.
    sns.heatmap(
        df_matrix_contrast,
        cmap=cmap,
        vmin=0,
        vmax=1,
        cbar=False,
        ax=ax,
        linewidths=0,
        square=True,
    )

    # --- Annotate with raw values; text colour adapts to the painted colour ---
    shades = df_matrix_contrast.to_numpy()
    for i, metric in enumerate(metric_order):
        for j, method in enumerate(method_order):
            shade = shades[i, j]
            if pd.isna(shade):
                continue  # masked cell: nothing is drawn there

            raw_val = results_df.loc[metric, method]

            # Relative luminance of the colour actually shown in the cell.
            rgb = np.array(cmap(float(shade))[:3])
            luminance = 0.2126 * rgb[0] + 0.7152 * rgb[1] + 0.0722 * rgb[2]
            text_color = "black" if luminance > 0.5 else "white"

            ax.text(
                j + 0.5, i + 0.5,
                format_number(raw_val),
                ha='center', va='center',
                color=text_color,
                fontsize=8,
                fontweight='normal',
            )

    # --- Formatting ---
    ax.set_title(title, fontsize=10, pad=14, fontweight='normal')
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=10, fontweight='normal')
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=10, fontweight='normal')

    for spine in ax.spines.values():
        spine.set_visible(False)

    plt.tight_layout()

    if savefig:
        # No explicit format: matplotlib infers it from the file extension,
        # which also works for pathlib paths.
        fig.savefig(
            savefig,
            bbox_inches="tight",
            dpi=300,
            transparent=True
        )
        print(f"Saved: {savefig}")

    plt.show()



In [None]:
# Re-read the pre-computed clusterings exported earlier instead of recomputing them.
# (An earlier export lived in 'benchmark_ovarian_adata_obs.csv'.)
#adata_obs= pd.read_csv('benchmark_ovarian_adata_obs.csv',index_col=0)
adata_obs= pd.read_csv('benchmark_ovarian_adata_obs_new.csv',index_col=0)
# Cast the CSV index back to strings before attaching the table to adata.
adata_obs.index = adata_obs.index.astype(str)

# Replaces the whole obs table.
# NOTE(review): assumes the CSV rows are in the same order as adata's cells — confirm.
adata.obs = adata_obs

In [None]:
# Append the PROGENy pathway activity estimates as extra obs columns.
# pd.concat(axis=1) appends columns every time it runs, so executing this cell
# twice leaves duplicated pathway column names in adata.obs.
# NOTE(review): pathway_matrix is defined outside this section; presumably it is
# indexed by cell name — confirm.
adata.obs = pd.concat([adata.obs,pathway_matrix],axis=1)

In [None]:
# Display name -> obs column holding that method's cluster labels.
# The insertion order is the column order of the benchmark heatmap and the
# panel order of the spatial cluster figure.
clustering_keys = {
    'SpatialFusion': 'leiden_gcn',
    'SpatialFusion (concat)': 'leiden_gcn_concat',
    'SpatialFusion (H&E)': 'leiden_gcn_he',
    'SpatialFusion (RNA)': 'leiden_gcn_rna',
    'SpatialFusion (recon)': 'leiden_gcn_onlyrecon',
    'SpatialFusion (recon concat)': 'leiden_gcn_onlyrecon_concat',
    'SpatialFusion (recon H&E)': 'leiden_gcn_onlyrecon_he',
    'SpatialFusion (recon RNA)': 'leiden_gcn_onlyrecon_rna',
    'NicheCompass': 'leiden_nichecompass',
    'BANKSY': 'leiden_banksy_08',
    'Nicheformer': 'leiden_nicheformer',
    'scGPT-spatial': 'leiden_scgptspatial',
    'OmiCLIP text': 'leiden_OmiCLIP_text',
    'OmiCLIP image': 'leiden_OmiCLIP_image',
    'Scanpy': 'leiden_scanpy',
}

# Metrics as rows, methods as columns; uses the default ground-truth column
# ('path_region') and coordinates ('spatial_px').
results_df = compute_all_metrics(adata, clustering_keys)


In [None]:
import matplotlib
# Keep text as editable text objects in exported SVGs instead of converting glyphs to paths.
matplotlib.rcParams['svg.fonttype'] = 'none'
# Draw the benchmark heatmap and write it next to the other Fig. 2 panels.
plot_benchmark_heatmap(results_df, title="OVCA Benchmark", savefig='../../../SpatialFusion/results/figures_Fig2/NEW_OVCA_benchmark.svg')

## Plot the clusters in space

In [None]:
def plot_spatial_clusters_panel(
    adata,
    method_mapping,
    color_dict,
    coord_keys=("X_coord", "Y_coord"),
    ncols=5,
    savefig=None,
    rasterize_points=True,
):
    """
    Plot spatial clustering panels for multiple methods in a Nature Genetics style.

    One scatter panel is drawn per method (y axis pointing down), followed by a
    separate legend-only figure.

    Parameters
    ----------
    adata : AnnData
        ``.obs`` must contain the coordinate columns and every cluster column.
    method_mapping : dict
        Mapping {panel title: obs column with cluster labels}; insertion order
        is the panel order.
    color_dict : dict
        Mapping {obs column: {cluster label (str): color}}.
    coord_keys : tuple of str
        Names of the x and y coordinate columns in ``adata.obs``.
    ncols : int
        Number of panel columns in the grid.
    savefig : str, optional
        Output *prefix*: the panels go to ``f"{savefig}_panel.png"`` and the
        legend to ``f"{savefig}_legend.svg"``.
    rasterize_points : bool
        Rasterize the scatter points.

    Raises
    ------
    ValueError
        If ``method_mapping`` is empty.

    Notes
    -----
    The legend is built from the first palette in ``color_dict`` only; labels
    missing from that palette get no legend entry.
    """
    if not method_mapping:
        raise ValueError("method_mapping is empty: no clustering to plot.")

    x_key, y_key = coord_keys
    method_keys = list(method_mapping.values())
    n_methods = len(method_keys)
    nrows = math.ceil(n_methods / ncols)

    # --- All cluster IDs seen across methods (as strings) ---
    all_labels = np.unique(
        np.concatenate([
            adata.obs[k].astype(str).values for k in method_keys
        ])
    )

    # --- Figure style ---
    sns.set_style("white")
    sns.set_context("talk", font_scale=1.3)

    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(5.5 * ncols, 5 * nrows), dpi=300
    )
    axes = np.array(axes).reshape(-1)

    for ax, (display_name, method) in zip(axes, method_mapping.items()):
        # Convert hue column to string to match the color_dict keys.
        hue_values = adata.obs[method].astype(str)

        sns.scatterplot(
            x=adata.obs[x_key],
            y=adata.obs[y_key],
            hue=hue_values,
            palette=color_dict[method],
            s=1,
            linewidth=0,
            alpha=0.9,
            ax=ax,
            legend=False,
            rasterized=rasterize_points,
        )

        ax.invert_yaxis()
        ax.set_aspect("equal")

        # --- Titles & styling ---
        ax.set_title(display_name, fontsize=14, fontweight="normal", pad=10)
        ax.set_xlabel("", fontsize=18, labelpad=8, fontweight="normal")
        ax.set_ylabel("", fontsize=18, labelpad=8, fontweight="normal")
        ax.set_xticks([]); ax.set_yticks([])
        for spine in ["top", "right", "left", "bottom"]:
            ax.spines[spine].set_visible(False)
        ax.grid(False)

    # Hide grid cells that have no method assigned.
    for unused_ax in axes[n_methods:]:
        unused_ax.set_visible(False)

    fig.tight_layout()

    # --- Shared legend (taken from the first palette in color_dict) ---
    rep_method = list(color_dict.keys())[0]
    rep_palette = color_dict[rep_method]

    def _cluster_sort_key(label):
        # Numeric cluster IDs sort by value ("2" before "10"); others alphabetically after them.
        return (0, int(label), "") if label.isdecimal() else (1, 0, label)

    handles = []
    for label in sorted((str(l) for l in all_labels), key=_cluster_sort_key):
        if label in rep_palette:
            handles.append(
                plt.Line2D(
                    [0], [0],
                    marker="o",
                    color="none",
                    markerfacecolor=rep_palette[label],
                    markersize=8,
                    label=label,
                )
            )

    # Keep a positive height even when no label matched the palette.
    legend_fig = plt.figure(figsize=(2.5, max(0.4 * len(handles), 0.4)), dpi=300)
    legend_fig.legend(
        handles=handles,
        loc="center",
        title="Cluster",
        frameon=False,
        ncol=1,
        fontsize=12,
        title_fontsize=14,
    )
    legend_fig.tight_layout()

    if savefig:
        fig.savefig(
            f"{savefig}_panel.png",
            dpi=200,
            bbox_inches="tight",
            transparent=True,
        )
        legend_fig.savefig(
            f"{savefig}_legend.svg",
            dpi=300,
            bbox_inches="tight",
            transparent=True,
        )
        print(f"Saved: {savefig}_panel.png and {savefig}_legend.svg")

    plt.show()
    plt.close(legend_fig)

In [None]:
# obs columns holding cluster labels; one palette is built for each of them below.
leiden_cols = [
       'leiden_gcn', 'leiden_gcn_he', 'leiden_banksy_08',
       'leiden_nichecompass', 'leiden_nicheformer', 'leiden_scgptspatial',
       'leiden_scanpy', 'leiden_OmiCLIP_text',
       'leiden_OmiCLIP_image',
       'leiden_gcn_onlyrecon', 'leiden_gcn_onlyrecon_tied', 
       'leiden_gcn_concat', 'leiden_gcn_rna', 'leiden_gcn_onlyrecon_concat',
       'leiden_gcn_onlyrecon_he', 'leiden_gcn_onlyrecon_rna']

In [None]:
# Every cluster column uses the same color list (tab20_filtered).
palette_specs = dict.fromkeys(leiden_cols, tab20_filtered)

# {obs column: {cluster label: color}}; columns missing from adata.obs are
# skipped with a warning by build_palettes_from_adata.
palette_dict_2 = build_palettes_from_adata(adata, palette_specs)

In [None]:
# NOTE(review): method_keys is not used in this cell — confirm whether a later cell needs it.
method_keys = list(clustering_keys.values())

# One spatial panel per benchmarked method, five panels per row.
# savefig is used as a prefix by plot_spatial_clusters_panel, which appends
# "_panel.png" and "_legend.svg" to it; the ".svg" here therefore ends up in
# the middle of the output file names.
plot_spatial_clusters_panel(
    adata,
    color_dict=palette_dict_2,
    method_mapping=clustering_keys,
    ncols=5,
    savefig='../../../SpatialFusion/results/figures_Fig2/panel_viz_clusters_NEW.svg',
)


# Downstream analysis

## Helper functions

In [None]:
def _transform_x(aff_transf: pd.DataFrame, coords: np.ndarray) -> np.ndarray:
    """Why do we need this? The H&E image is not naturally aligned to the Xenium output. This can be done through the
    Xenium
    """

    inv_transf = np.linalg.inv(aff_transf)
    transformed_coords = (inv_transf @ np.vstack((coords.T, np.ones(len(coords))))).T[
        :, :-1
    ]

    return transformed_coords
    
# Alignment matrix from 10X: a 3x3 affine matrix in homogeneous form.
# It is passed to _transform_x as aff_transf, which applies its inverse to the
# cell coordinates in adata.obsm["spatial_px"].
M = np.array([
    [0.010908748623278200,  1.2895248946320600, -721.007456942807],
    [-1.2895248946320600,  0.010908748623278200, 38642.677876412400],
    [0, 0, 1]
])

In [None]:
def plot_annotation(
    ax,
    adata,
    column,
    title,
    palette,
    vmin=-5,
    vmax=5,
    x_key="X_he",
    y_key="Y_he",
    point_size=3,
    xlim=None,
    ylim=None,
    colorbar_info=None,
    legends_info=None,
    rasterize_points=True,
):
    """
    Plot an annotation layer (continuous or categorical) in Nature Genetics style.

    Numeric columns are clipped to [vmin, vmax] and drawn with a colormap;
    every other column is cast to string and drawn with a categorical palette.

    Parameters
    ----------
    ax : matplotlib Axes
        Axis to draw on.
    adata : AnnData
        AnnData object containing .obs[column] and coordinates.
    column : str
        Column name in adata.obs to visualize.
    title : str
        Panel title to display.
    palette : str, dict, or colormap
        Palette for categorical data or cmap for continuous data.
    vmin, vmax : float
        Limits for continuous color scaling (clipping).
    x_key, y_key : str
        obs columns holding the spatial coordinates.
    point_size : float
        Scatter point size.
    xlim, ylim : tuple or None
        Manual axis limits. They are applied verbatim after the y axis has been
        inverted, so pass ``ylim`` as (high, low) to keep y pointing down.
    colorbar_info : list, optional
        Receives a (scatter, title) tuple for continuous columns, for later
        colorbar plotting.
    legends_info : list, optional
        Receives a (title, handles, labels) tuple for categorical columns, for
        a separate legend figure.
    rasterize_points : bool
        Rasterize scatter for smaller vector file size.

    Raises
    ------
    KeyError
        If ``column`` is not in ``adata.obs``.
    ValueError
        If ``column`` matches several obs columns.
    """

    # --- verify column
    if column not in adata.obs.columns:
        raise KeyError(f"Column '{column}' not found in adata.obs")

    values = adata.obs[column]
    if isinstance(values, pd.DataFrame):
        raise ValueError(f"Column '{column}' is not unique in adata.obs (multiple matches)")

    # --- style
    sns.set_style("white")

    # === Continuous variable ===
    if pd.api.types.is_numeric_dtype(values):
        clipped_values = values.clip(lower=vmin, upper=vmax)

        scatter_artist = ax.scatter(
            adata.obs[x_key],
            adata.obs[y_key],
            c=clipped_values.values,
            cmap=palette,
            s=point_size,
            alpha=0.8,
            linewidth=0,
            vmin=vmin,
            vmax=vmax,
            rasterized=rasterize_points,
        )

        if colorbar_info is not None:
            colorbar_info.append((scatter_artist, title))

    # === Categorical variable ===
    else:
        hue_vals = values.astype(str)
        # Legend handles only exist on the axis when seaborn is asked to build
        # a legend, so request one when the caller wants the entries and
        # remove the drawn legend again right after collecting them.
        collect_legend = legends_info is not None

        sns.scatterplot(
            x=adata.obs[x_key],
            y=adata.obs[y_key],
            hue=hue_vals,
            palette=palette,
            s=point_size,
            linewidth=0,
            alpha=0.9,
            ax=ax,
            legend=collect_legend,
            rasterized=rasterize_points,
        )

        if collect_legend:
            handles, labels = ax.get_legend_handles_labels()
            by_label = {}
            for h, l in zip(handles, labels):
                if l and l != "_nolegend_" and l not in by_label:
                    by_label[l] = h
            legends_info.append((title, list(by_label.values()), list(by_label.keys())))

            panel_legend = ax.get_legend()
            if panel_legend is not None:
                panel_legend.remove()

    # --- general aesthetics ---
    ax.set_title(title, pad=10)
    ax.set_aspect("equal")
    # invert_yaxis() is a toggle and limits propagate between sharey axes, so
    # only flip when the axis is not inverted yet; otherwise every second panel
    # of a shared grid would flip the orientation back.
    if not ax.yaxis_inverted():
        ax.invert_yaxis()

    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    # no axes, ticks, or spines
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ["top", "right", "left", "bottom"]:
        ax.spines[spine].set_visible(False)
    ax.grid(False)


In [None]:
def plot_annotation_with_HE(ax, column, title, palette):
    """
    Draw the H&E crop with one categorical annotation overlaid as a scatter.

    Reads these notebook-level variables at call time: ``adata``,
    ``point_size``, ``roi``, ``x0``/``x1``/``y0``/``y1`` and ``legends_info``
    (a (title, handles, labels) tuple is appended to the latter).

    Parameters
    ----------
    ax : matplotlib Axes
        Panel to draw on.
    column : str
        Column of ``adata.obs``; its values are cast to strings.
    title : str
        Panel title, also used as the legend title entry.
    palette : dict, list or str
        Categorical palette forwarded to seaborn.
    """
    obs = adata.obs
    categories = obs[column].astype(str)
    scatter_kws = dict(
        x=obs["X_he"], y=obs["Y_he"],
        hue=categories, palette=palette,
        s=point_size, linewidth=0,
    )

    # --- legend handles come from a throw-away figure drawn with legend=True ---
    scratch_fig = plt.figure()
    scratch_ax = scratch_fig.add_subplot(111)
    sns.scatterplot(ax=scratch_ax, alpha=0.9, legend=True, **scratch_kws_passthrough(scatter_kws) if False else scatter_kws)
    handles, labels = scratch_ax.get_legend_handles_labels()
    plt.close(scratch_fig)

    # One entry per label (a repeated label keeps its last handle).
    entries = {}
    for handle, label in zip(handles, labels):
        if label and label != "_nolegend_":
            entries[label] = handle
    legends_info.append((title, list(entries.values()), list(entries.keys())))

    # --- H&E background placed in tissue coordinates ---
    ax.imshow(roi, origin="upper", extent=(x0, x1, y1, y0))

    # --- overlay (rasterized to keep the .svg small) ---
    sns.scatterplot(ax=ax, alpha=0.8, legend=False, rasterized=True, **scatter_kws)

    # --- axis limits and orientation ---
    # NOTE(review): set_ylim(y1, y0) already puts y0 at the top; the
    # invert_yaxis() call that follows flips the axis once more — confirm the
    # resulting orientation is the intended one.
    ax.set_xlim(x0, x1)
    ax.set_ylim(y1, y0)
    ax.invert_yaxis()
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_aspect("equal")
    ax.set_title(title, fontsize=16, fontweight="normal", pad=6)

    # --- bare panel: no ticks, spines or grid ---
    ax.set_xticks([])
    ax.set_yticks([])
    for side in ("top", "right", "left", "bottom"):
        ax.spines[side].set_visible(False)
    ax.grid(False)

In [None]:
def plot_annotation_wo_HE(ax, column, title, palette):
    """
    Single Nature Genetics–style scatter panel (no H&E background).

    Reads these notebook-level variables at call time: ``adata``,
    ``point_size``, ``xlim``, ``ylim`` and ``legends_info`` (a
    (title, handles, labels) tuple is appended to the latter).

    Parameters
    ----------
    ax : matplotlib Axes
        Panel to draw on.
    column : str
        Column of ``adata.obs``; its values are cast to strings.
    title : str
        Panel title, also used as the legend title entry.
    palette : dict, list or str
        Categorical palette forwarded to seaborn.
    """
    obs = adata.obs
    categories = obs[column].astype(str)
    scatter_kws = dict(
        x=obs["X_he"],
        y=obs["Y_he"],
        hue=categories,
        palette=palette,
        s=point_size,
        linewidth=0,
        alpha=0.9,
    )

    # Legend handles come from a throw-away figure drawn with legend=True.
    scratch_fig = plt.figure()
    scratch_ax = scratch_fig.add_subplot(111)
    sns.scatterplot(ax=scratch_ax, legend=True, **scatter_kws)
    handles, labels = scratch_ax.get_legend_handles_labels()
    plt.close(scratch_fig)

    # One entry per label (a repeated label keeps its first handle).
    entries = {}
    for handle, label in zip(handles, labels):
        if label and label != "_nolegend_":
            entries.setdefault(label, handle)
    legends_info.append((title, list(entries.values()), list(entries.keys())))

    # Actual panel: no legend, rasterized points.
    sns.scatterplot(ax=ax, legend=False, rasterized=True, **scatter_kws)

    # NOTE(review): set_ylim(ylim) runs after invert_yaxis(), so the final
    # direction of the y axis is whatever the order of the ylim tuple says —
    # confirm the resulting orientation is the intended one.
    ax.invert_yaxis()
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_aspect("equal")

    ax.set_title(title, pad=6, fontweight="normal")

    # Bare panel: no ticks, spines or grid.
    ax.set_xticks([])
    ax.set_yticks([])
    for side in ("top", "right", "left", "bottom"):
        ax.spines[side].set_visible(False)
    ax.grid(False)


In [None]:
# Load the whole-slide H&E image and express the cell coordinates in its frame.
source_image_path = '../../../Broad_SpatialFoundation/test_data/10X_Xenium_Ovarian_5k/Xenium_Prime_Ovarian_Cancer_FFPE_XRrun_he_image.ome.tif'

# Read the first series of the OME-TIFF into memory.
with tifffile.TiffFile(source_image_path) as tif:
    wsi = tif.series[0].asarray()

coords = adata.obsm["spatial_px"]
# NOTE(review): cell_names is not used in this cell — confirm whether a later cell needs it.
cell_names = adata.obs_names.to_numpy()

# Apply the inverse of the 10X alignment matrix M to the cell coordinates.
transformed_coords = _transform_x(aff_transf=M, coords=coords)
# Clip at 0 because the transformation occasionally yields slightly negative
# coordinates; this should be minor
# (e.g. 1 of 150,000 cells in another dataset that was evaluated).
print(
    f"There are {((transformed_coords<0).sum(axis=1)>0).sum()} cells with negative coordinates, clipping at 0."
)
transformed_coords = transformed_coords.clip(0)


adata.obsm['spatial_he'] = transformed_coords

# Also expose the coordinates as obs columns; the plotting helpers read X_he / Y_he.
adata.obs['X_he'] = adata.obsm['spatial_he'][:,0]
adata.obs['Y_he'] = adata.obsm['spatial_he'][:,1]

In [None]:
# Quick look at the obs columns available for plotting.
adata.obs.columns

In [None]:
# Whole-tissue overview in H&E coordinates, colored by pathologist region.
# The commented-out hue/palette pair switches the coloring to the leiden_gcn clusters.
fig, ax = plt.subplots(1,1, figsize=(10,10))
sns.scatterplot(
    x=adata.obs["X_he"],
    y=adata.obs["Y_he"],
    #hue=adata.obs['leiden_gcn'].astype(str),
    #palette=palette_dict_2['leiden_gcn'],
    hue=adata.obs['path_region'].astype(str),
    palette=palette_dict_1['path_region'],
    s=4,
    ax=ax,
    linewidth=0,
    alpha=0.9,
    legend=False,
)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# === Parameters ===
# plot_annotation_wo_HE reads xlim, ylim, point_size and legends_info from the
# notebook namespace, so they must be defined before the plotting loop runs.
xlim = (7500, 17500)
ylim = (28000, 38000)
point_size = 3
figsize = (14, 15)
save_prefix = "../../../SpatialFusion/results/figures_Fig2/OVCA_fallopian_zoom"

# (obs column, panel title, palette) for each of the six panels.
plot_configs = [
    ("leiden_gcn", "SpatialFusion", palette_dict_2["leiden_gcn"]),
    ("leiden_nichecompass", "NicheCompass", palette_dict_2["leiden_nichecompass"]),
    ("leiden_gcn_he", "SpatialFusion (H&E)", palette_dict_2["leiden_gcn_he"]),
    ("path_region", "Pathologist Region", palette_dict_1["path_region"]),
    ("major_celltype", "Major Cell Type", palette_dict_1["major_celltype"]),
    ("minor_celltype", "Minor Cell Type", palette_dict_1["minor_celltype"]),
]


fig, axes = plt.subplots(3, 2, figsize=figsize, sharex=True, sharey=True)
axes = axes.flatten()
# Filled by plot_annotation_wo_HE with one (title, handles, labels) tuple per panel.
legends_info = []


# === Generate main figure ===
for ax, (col, title, pal) in zip(axes, plot_configs):
    plot_annotation_wo_HE(ax, col, title, pal)

plt.tight_layout()
plt.show()

# === Save main panel as lightweight SVG ===
fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"Saved main figure: {save_prefix}_panel.svg")


# === Legends-only figure ===
# One blank axis per panel, each holding only that panel's legend.
fig_leg, axes_leg = plt.subplots(3, 2, figsize=(figsize[0], 8), squeeze=False)
axes_leg = axes_leg.flatten()

for ax, (title, handles, labels) in zip(axes_leg, legends_info):
    ax.axis("off")
    ax.set_title(f"{title}", fontsize=14, fontweight="normal", pad=6)
    ax.legend(
        handles, labels,
        loc="center left",
        bbox_to_anchor=(0.0, 0.5),
        frameon=False,
        ncol=2,
        handletextpad=0.6,
        labelspacing=0.4,
        borderaxespad=0.0,
        markerscale=3,
        fontsize=12,
    )

plt.tight_layout()
plt.show()

# === Save legends as separate SVG ===
fig_leg.savefig(f"{save_prefix}_legends.svg", dpi=250, bbox_inches="tight", transparent=True)
print(f"Saved legends: {save_prefix}_legends.svg")


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# === Parameters ===
# plot_annotation_with_HE reads point_size, roi, x0/x1/y0/y1 and legends_info
# from the notebook namespace, so they must be defined before the plotting loop runs.
# Previously used window:
#xlim = (7500, 17500)
#ylim = (28000, 38000)
xlim = (8500, 16500)
ylim = (30000, 36000)
point_size = 5
figsize = (14, 17)
save_prefix = "../../../SpatialFusion/results/figures_Fig2/OVCA_fallopian_he"

# --- Crop region of interest from the H&E image ---
# assume 'wsi' is your RGB numpy array of the full H&E
x0, x1 = xlim
y0, y1 = ylim
roi = wsi[y0:y1, x0:x1, :]  # numpy is row-major (y, x)

# --- Define which columns and palettes to use ---
# (obs column, panel title, palette) for each of the six panels.
plot_configs = [
    ("leiden_gcn", "SpatialFusion", palette_dict_2["leiden_gcn"]),
    ("leiden_nichecompass", "NicheCompass", palette_dict_2["leiden_nichecompass"]),
    ("leiden_banksy_08", "BANKSY", palette_dict_2["leiden_banksy_08"]),
    ("path_region", "Pathologist Region", palette_dict_1["path_region"]),
    ("major_celltype", "Major Cell Type", palette_dict_1["major_celltype"]),
    ("minor_celltype", "Minor Cell Type", palette_dict_1["minor_celltype"]),
]

# --- Set up figure grid ---
fig, axes = plt.subplots(3, 2, figsize=figsize, sharex=True, sharey=True)
axes = axes.flatten()
# Filled by plot_annotation_with_HE with one (title, handles, labels) tuple per panel.
legends_info = []

# --- Generate panels ---
for ax, (col, title, pal) in zip(axes, plot_configs):
    plot_annotation_with_HE(ax, col, title, pal)

plt.tight_layout()
plt.show()

# === Save main composite figure ===
fig.savefig(f"{save_prefix}_panel.svg", dpi=200, bbox_inches="tight", transparent=True)
print(f"✅ Saved main figure: {save_prefix}_panel.svg")

# === Separate legend figure ===
# One blank axis per panel, each holding only that panel's legend.
fig_leg, axes_leg = plt.subplots(3, 2, figsize=(figsize[0], 8), squeeze=False)
axes_leg = axes_leg.flatten()

for ax, (title, handles, labels) in zip(axes_leg, legends_info):
    ax.axis("off")
    ax.set_title(f"{title}", fontsize=14, fontweight="normal", pad=6)
    ax.legend(
        handles, labels,
        loc="center left",
        bbox_to_anchor=(0.0, 0.5),
        frameon=False,
        ncol=2,
        handletextpad=0.6,
        labelspacing=0.4,
        borderaxespad=0.0,
        markerscale=3,
        fontsize=12,
    )

plt.tight_layout()
plt.show()

fig_leg.savefig(f"{save_prefix}_legends.svg", dpi=250, bbox_inches="tight", transparent=True)
print(f"✅ Saved legends: {save_prefix}_legends.svg")


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# === Parameters ===
# plot_annotation_with_HE reads point_size, roi, x0/x1/y0/y1 and legends_info
# from the notebook namespace, so they must be defined before the plotting loop runs.
xlim = (10000, 17500)
ylim = (10000, 17500)
point_size = 3
figsize = (14, 17)
save_prefix = "../../../SpatialFusion/results/figures_Fig2/OVCA_necrotic_he"

# --- Crop region of interest from the H&E image ---
# assume 'wsi' is your RGB numpy array of the full H&E
x0, x1 = xlim
y0, y1 = ylim
roi = wsi[y0:y1, x0:x1, :]  # numpy is row-major (y, x)

# --- Define which columns and palettes to use ---
# (obs column, panel title, palette) for each of the six panels.
plot_configs = [
    ("leiden_gcn", "SpatialFusion", palette_dict_2["leiden_gcn"]),
    ("leiden_nichecompass", "NicheCompass", palette_dict_2["leiden_nichecompass"]),
    ("leiden_banksy_08", "BANKSY", palette_dict_2["leiden_banksy_08"]),
    ("path_region", "Pathologist Region", palette_dict_1["path_region"]),
    ("major_celltype", "Major Cell Type", palette_dict_1["major_celltype"]),
    ("minor_celltype", "Minor Cell Type", palette_dict_1["minor_celltype"]),
]

# --- Set up figure grid ---
fig, axes = plt.subplots(3, 2, figsize=figsize, sharex=True, sharey=True)
axes = axes.flatten()
# Filled by plot_annotation_with_HE with one (title, handles, labels) tuple per panel.
legends_info = []

# --- Generate panels ---
for ax, (col, title, pal) in zip(axes, plot_configs):
    plot_annotation_with_HE(ax, col, title, pal)

plt.tight_layout()
plt.show()

# === Save main composite figure ===
fig.savefig(f"{save_prefix}_panel.svg", dpi=200, bbox_inches="tight", transparent=True)
print(f"✅ Saved main figure: {save_prefix}_panel.svg")

# === Separate legend figure ===
# One blank axis per panel, each holding only that panel's legend.
fig_leg, axes_leg = plt.subplots(3, 2, figsize=(figsize[0], 8), squeeze=False)
axes_leg = axes_leg.flatten()

for ax, (title, handles, labels) in zip(axes_leg, legends_info):
    ax.axis("off")
    ax.set_title(f"{title}", fontsize=14, fontweight="normal", pad=6)
    ax.legend(
        handles, labels,
        loc="center left",
        bbox_to_anchor=(0.0, 0.5),
        frameon=False,
        ncol=2,
        handletextpad=0.6,
        labelspacing=0.4,
        borderaxespad=0.0,
        markerscale=3,
        fontsize=12,
    )

plt.tight_layout()
plt.show()

fig_leg.savefig(f"{save_prefix}_legends.svg", dpi=250, bbox_inches="tight", transparent=True)
print(f"✅ Saved legends: {save_prefix}_legends.svg")


## Pathway activity

In [None]:
import matplotlib
# Keep text as editable text objects in exported SVGs instead of converting glyphs to paths.
matplotlib.rcParams['svg.fonttype'] = 'none'

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

def plot_annotation_pathway(
    ax,
    column,
    title,
    palette,
    vmin=-5,
    vmax=5,
    point_size=2,
    xlim=None,
    ylim=None,
    colorbar_info=None,
    legends_info=None,
):
    """
    Nature Genetics–style overlay for continuous or categorical variables.
    - Continuous: values clipped to [vmin, vmax] and drawn with a colormap.
    - Categorical: drawn with a seaborn palette; legend entries are collected.

    Reads ``adata`` from the notebook namespace; coordinates come from its
    ``X_he`` / ``Y_he`` obs columns. The y axis always ends up inverted.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axis to draw on.
    column : str
        Column in adata.obs to plot.
    title : str
        Panel title (not bold).
    palette : str or dict
        Colormap for continuous or palette dict for categorical data.
    vmin, vmax : float
        Range for clipping continuous values.
    point_size : int or float
        Marker size.
    xlim, ylim : tuple
        Plot bounds (optional).
    colorbar_info : list, optional
        Receives a (scatter, title) tuple for continuous columns.
    legends_info : list, optional
        Receives a (title, handles, labels) tuple for categorical columns.

    Raises
    ------
    KeyError
        If ``column`` is not in ``adata.obs``.
    ValueError
        If ``column`` matches several obs columns.
    """

    sns.set_style("white")
    sns.set_context("talk")

    # --- check data column ---
    if column not in adata.obs.columns:
        raise KeyError(f"Column '{column}' not found in adata.obs")

    values = adata.obs[column]
    if isinstance(values, pd.DataFrame):
        raise ValueError(f"Column '{column}' has multiple matches in adata.obs")

    # --- continuous variable ---
    if pd.api.types.is_numeric_dtype(values):
        clipped_values = values.clip(lower=vmin, upper=vmax)
        scatter_artist = ax.scatter(
            adata.obs["X_he"], adata.obs["Y_he"],
            c=clipped_values.values,
            cmap=palette,
            s=point_size,
            alpha=0.8,
            linewidth=0,
            vmin=vmin,
            vmax=vmax,
            rasterized=True,
        )
        if colorbar_info is not None:
            colorbar_info.append((scatter_artist, title))

    # --- categorical variable ---
    else:
        # Legend handles only exist on the axis when seaborn is asked to build
        # a legend, so request one when the caller wants the entries and
        # remove the drawn legend again right after collecting them.
        collect_legend = legends_info is not None
        sns.scatterplot(
            data=adata.obs,
            x="X_he",
            y="Y_he",
            hue=column,
            palette=palette,
            s=point_size,
            ax=ax,
            linewidth=0,
            alpha=0.9,
            legend=collect_legend,
            rasterized=True,
        )
        if collect_legend:
            handles, labels = ax.get_legend_handles_labels()
            by_label = {l: h for h, l in zip(handles, labels) if l and l != "_nolegend_"}
            legends_info.append((title, list(by_label.values()), list(by_label.keys())))

            panel_legend = ax.get_legend()
            if panel_legend is not None:
                panel_legend.remove()

    # --- aesthetic cleanup ---

    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    # invert_yaxis() is a toggle and limits propagate between sharey axes, so
    # only flip when the axis is not inverted yet; otherwise every second panel
    # of a shared grid would flip the orientation back.
    if not ax.yaxis_inverted():
        ax.invert_yaxis()
    ax.set_aspect("equal")
    ax.set_xticks([]); ax.set_yticks([])
    ax.set_xlabel(""); ax.set_ylabel("")
    ax.set_title(title, fontsize=16, fontweight="normal", pad=6)

    for spine in ["top", "right", "left", "bottom"]:
        ax.spines[spine].set_visible(False)
    ax.grid(False)


In [None]:
# === Parameters ===
# Zoomed window for the pathway activity panels.
xlim = (10000, 20000)
ylim = (5000, 15000)
# NOTE(review): point_size is not passed to plot_annotation_pathway below, so
# that function's own default (2) is what gets drawn — confirm.
point_size = 3
figsize = (14, 20)
save_prefix = "../../../SpatialFusion/results/figures_Fig2/OVCA_pathway_activity"

# === Plot settings ===
# (obs column, panel title, colormap) for each pathway.
plot_configs = [
    ("Androgen", "Androgen Activity", "vlag"),
    ("EGFR", "EGFR Activity", "vlag"),
    ("Estrogen", "Estrogen Activity", "vlag"), 
    ("JAK-STAT", "JAK-STAT Activity", "vlag"),
    ("MAPK", "MAPK Activity", "vlag"),
    ("NFkB", "NFkB Activity", "vlag"),
    ("PI3K", "PI3K Activity", "vlag"),
    ("TGFb", "TGFb Activity", "vlag"),
    ("TNFa", "TNFa Activity", "vlag"),
    ("VEGF", "VEGF Activity", "vlag"),
 
]

# === Create subplots ===
fig, axes = plt.subplots(5, 2, figsize=figsize, sharex=True, sharey=True)
axes = axes.flatten()

# NOTE(review): neither list is passed to plot_annotation_pathway in the loop
# below, so both stay empty after this cell — confirm.
legends_info = []
colorbar_info = []  # meant to collect (scatter, title) tuples for continuous plots

# === Generate plots ===
# vmin/vmax are left at the function defaults (-5, 5).
for ax, (col, title, pal) in zip(axes, plot_configs):
    plot_annotation_pathway(ax, col, title, pal, xlim=xlim,
    ylim=ylim,)

plt.tight_layout()
plt.show()

fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"✅ Saved panel: {save_prefix}_panel.svg")

In [None]:
# === Parameters ===
# Whole-tissue version of the pathway activity panels (no xlim/ylim passed).
# NOTE(review): point_size is not passed to plot_annotation below; the function
# default is used instead — confirm.
point_size = 3
figsize = (14, 25)
save_prefix = "../../../SpatialFusion/results/figures_Fig2/OVCA_pathway_activity_full"

# === Plot settings ===
# (obs column, panel title, colormap) for each pathway.
plot_configs = [
    ("Androgen", "Androgen Activity", "vlag"),
    ("EGFR", "EGFR Activity", "vlag"),
    ("Estrogen", "Estrogen Activity", "vlag"), 
    ("JAK-STAT", "JAK-STAT Activity", "vlag"),
    ("MAPK", "MAPK Activity", "vlag"),
    ("NFkB", "NFkB Activity", "vlag"),
    ("PI3K", "PI3K Activity", "vlag"),
    ("TGFb", "TGFb Activity", "vlag"),
    ("TNFa", "TNFa Activity", "vlag"),
    ("VEGF", "VEGF Activity", "vlag"),
 
]

# === Create subplots ===
fig, axes = plt.subplots(5, 2, figsize=figsize, sharex=True, sharey=True)
axes = axes.flatten()

# legends_info is not passed to plot_annotation below and stays empty here.
legends_info = []
colorbar_info = []  # collects one (scatter, title) tuple per continuous panel

# === Generate plots ===
# Color scale clipped to [-3, 3] for every pathway.
for ax, (col, title, pal) in zip(axes, plot_configs):
    plot_annotation(
        ax,
        adata,
        column=col,
        title=title,
        palette=pal,
        vmin=-3, vmax=3,
        colorbar_info=colorbar_info,
    )

plt.tight_layout()
plt.show()


fig.savefig(f"{save_prefix}_panel.svg", dpi=150, bbox_inches="tight", transparent=True)
print(f"✅ Saved panel: {save_prefix}_panel.svg")


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# === Parameters ===
xlim = (16500, 19000)
ylim = (7500, 10000)
point_size = 10
figsize = (14, 15)
save_prefix = "../../../SpatialFusion/results/figures_Fig2/OVCA_pathway_cluster"

# --- Crop region of interest from the H&E image ---
# `wsi` is assumed to be the full H&E as an RGB numpy array, indexed (y, x, channel).
(x0, x1), (y0, y1) = xlim, ylim
roi = wsi[y0:y1, x0:x1, :]

# --- Define which columns and palettes to use ---
# Each entry is (column, panel title, palette looked up under the same column name).
plot_configs = [
    (column, panel_title, palettes[column])
    for column, panel_title, palettes in [
        ("leiden_gcn", "SpatialFusion", palette_dict_2),
        ("leiden_nichecompass", "NicheCompass", palette_dict_2),
        ("leiden_banksy_08", "BANKSY", palette_dict_2),
        ("path_region", "Pathologist Region", palette_dict_1),
        ("major_celltype", "Major Cell Type", palette_dict_1),
        ("minor_celltype", "Minor Cell Type", palette_dict_1),
    ]
]

# --- Set up figure grid ---
legends_info = []
fig, axes = plt.subplots(nrows=3, ncols=2, figsize=figsize, sharex=True, sharey=True)
axes = axes.flatten()

# --- Generate panels ---
for ax, (col, title, pal) in zip(axes, plot_configs):
    plot_annotation_with_HE(ax, col, title, pal)

plt.tight_layout()
plt.show()

# === Save main composite figure ===
fig.savefig(
    f"{save_prefix}_panel.svg",
    dpi=200,
    bbox_inches="tight",
    transparent=True,
)
print(f"✅ Saved main figure: {save_prefix}_panel.svg")

# === Separate legend figure ===
# One axis per (title, handles, labels) entry collected in legends_info.
fig_leg, axes_leg = plt.subplots(nrows=3, ncols=2, figsize=(figsize[0], 8), squeeze=False)
axes_leg = axes_leg.flatten()

legend_style = dict(
    loc="center left",
    bbox_to_anchor=(0.0, 0.5),
    frameon=False,
    ncol=2,
    handletextpad=0.6,
    labelspacing=0.4,
    borderaxespad=0.0,
    markerscale=3,
    fontsize=12,
)

for ax, (title, handles, labels) in zip(axes_leg, legends_info):
    ax.axis("off")
    ax.set_title(f"{title}", fontsize=14, fontweight="normal", pad=6)
    ax.legend(handles, labels, **legend_style)

plt.tight_layout()
plt.show()

fig_leg.savefig(
    f"{save_prefix}_legends.svg",
    dpi=250,
    bbox_inches="tight",
    transparent=True,
)
print(f"✅ Saved legends: {save_prefix}_legends.svg")


In [None]:
# === Parameters ===
xlim = (16500, 19000)
ylim = (7500, 10000)
point_size = 20
figsize = (14, 20)
save_prefix = "../../../SpatialFusion/results/figures_Fig2/OVCA_pathway_activity_zoom"

# === Plot settings ===
# One (column, panel title, colormap) triple per pathway.
plot_configs = [
    (pathway, f"{pathway} Activity", "vlag")
    for pathway in (
        "Androgen", "EGFR", "Estrogen", "JAK-STAT", "MAPK",
        "NFkB", "PI3K", "TGFb", "TNFa", "VEGF",
    )
]

# === Create subplots ===
# 5 x 2 grid with shared axes; one panel per pathway.
fig, axes = plt.subplots(nrows=5, ncols=2, figsize=figsize, sharex=True, sharey=True)
axes = axes.flatten()

# Notebook-level accumulators, reset before drawing.
colorbar_info = []  # store (ax, column, cmap) for continuous plots
legends_info = []

# === Generate plots ===
# Each panel is restricted to the xlim/ylim window set above.
for ax, (col, title, pal) in zip(axes, plot_configs):
    plot_annotation_pathway(
        ax, col, title, pal,
        xlim=xlim, ylim=ylim, point_size=point_size,
    )

plt.tight_layout()
plt.show()

# === Save the panel ===
fig.savefig(
    f"{save_prefix}_panel.svg",
    dpi=150,
    bbox_inches="tight",
    transparent=True,
)
print(f"✅ Saved panel: {save_prefix}_panel.svg")

# Co-occurrence

In [None]:
# Step 1: Row-normalised co-occurrence table.
# Each row (pathologist region) is converted to the percentage of its
# observations that fall into each SpatialFusion niche.
heatmap_data = pd.crosstab(
    adata.obs['path_region'], adata.obs['leiden_gcn'], normalize='index'
) * 100
# Round to the nearest percent before casting; a bare astype(int) truncates,
# so 99.9 would be displayed as 99.
heatmap_data = heatmap_data.round().astype(int)

# Step 2: Plot the heatmap on an explicit figure/axes pair.
fig, ax = plt.subplots()
sns.heatmap(heatmap_data, annot=True, fmt="d", cmap="Blues", ax=ax)
ax.set_xlabel('SpatialFusion niches')
ax.set_ylabel('Region annotations')
ax.set_title('SpatialFusion co-occurrence')
fig.savefig('../../../SpatialFusion/results/figures_Fig2/SpatialFusion_confusion.svg', dpi=200, bbox_inches='tight')
plt.show()


In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Step 1: Row-normalised co-occurrence table.
# Each row (pathologist region) is converted to the percentage of its
# observations that fall into each SpatialFusion (H&E) niche.
heatmap_data = pd.crosstab(
    adata.obs['path_region'], adata.obs['leiden_gcn_he'], normalize='index'
) * 100
# Round to the nearest percent before casting; a bare astype(int) truncates,
# so 99.9 would be displayed as 99.
heatmap_data = heatmap_data.round().astype(int)

# Step 2: Plot the heatmap on an explicit figure/axes pair.
fig, ax = plt.subplots()
sns.heatmap(heatmap_data, annot=True, fmt="d", cmap="Blues", ax=ax)
ax.set_xlabel('SpatialFusion (H&E) niches')
ax.set_ylabel('Region annotations')
ax.set_title('SpatialFusion (H&E) co-occurrence')
fig.savefig('../../../SpatialFusion/results/figures_Fig2/SpatialFusion_he_confusion.svg', dpi=200, bbox_inches='tight')
plt.show()


In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Step 1: Row-normalised co-occurrence table.
# Each row (pathologist region) is converted to the percentage of its
# observations that fall into each NicheCompass niche.
heatmap_data = pd.crosstab(
    adata.obs['path_region'], adata.obs['leiden_nichecompass'], normalize='index'
) * 100
# Round to the nearest percent before casting; a bare astype(int) truncates,
# so 99.9 would be displayed as 99.
heatmap_data = heatmap_data.round().astype(int)

# Step 2: Plot the heatmap on an explicit figure/axes pair.
fig, ax = plt.subplots()
sns.heatmap(heatmap_data, annot=True, fmt="d", cmap="Blues", ax=ax)
ax.set_xlabel('NicheCompass niches')
ax.set_ylabel('Region annotations')
# Title spelling corrected ("co-occurrence") to match the SpatialFusion panels.
ax.set_title('NicheCompass co-occurrence')
fig.savefig('../../../SpatialFusion/results/figures_Fig2/NicheCompass_confusion.svg', dpi=200, bbox_inches='tight')
plt.show()


In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Step 1: Row-normalised co-occurrence table.
# Each row (pathologist region) is converted to the percentage of its
# observations that fall into each BANKSY niche.
heatmap_data = pd.crosstab(
    adata.obs['path_region'], adata.obs['leiden_banksy_08'], normalize='index'
) * 100
# Round to the nearest percent before casting; a bare astype(int) truncates,
# so 99.9 would be displayed as 99.
heatmap_data = heatmap_data.round().astype(int)

# Step 2: Plot the heatmap on an explicit figure/axes pair.
fig, ax = plt.subplots()
sns.heatmap(heatmap_data, annot=True, fmt="d", cmap="Blues", ax=ax)
ax.set_xlabel('BANKSY niches')
ax.set_ylabel('Region annotations')
# Title spelling corrected ("co-occurrence") to match the SpatialFusion panels.
ax.set_title('BANKSY co-occurrence')
fig.savefig('../../../SpatialFusion/results/figures_Fig2/BANKSY_confusion.svg', dpi=200, bbox_inches='tight')
plt.show()


# EXTRA: Comparing different cluster numbers

In [None]:
# this is to re-read pre-computed clustering 
#adata_obs= pd.read_csv('benchmark_ovarian_adata_obs_cl9.csv',index_col=0)
#adata_obs.index = adata_obs.index.astype(str)

#adata.obs = adata_obs

In [None]:
# Keep only the first 20 obs columns. Presumably these are the original
# metadata and this drops clustering columns from earlier runs — TODO confirm
# the column count against the loaded data.
adata.obs = adata.obs.iloc[:, :20]

In [None]:
# Neighbour graph on the 'gcn' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="gcn")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.1,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '9' through '17' into '8' and store the result as 'leiden_gcn'.
adata.obs['leiden_gcn'] = adata.obs['leiden'].replace(
    {str(label): '8' for label in range(9, 18)}
)

In [None]:
# Neighbour graph on the 'gcn_concat' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="gcn_concat")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.1,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '9' through '13' into '8' and store the result as 'leiden_gcn_concat'.
adata.obs['leiden_gcn_concat'] = adata.obs['leiden'].replace(
    {str(label): '8' for label in range(9, 14)}
)

In [None]:
# Neighbour graph on the 'gcn_he' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="gcn_he")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.07,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '9' through '13' into '8' and store the result as 'leiden_gcn_he'.
adata.obs['leiden_gcn_he'] = adata.obs['leiden'].replace(
    {str(label): '8' for label in range(9, 14)}
)

In [None]:
# Neighbour graph on the 'gcn_rna' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="gcn_rna")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.1,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '9' through '15' into '8' and store the result as 'leiden_gcn_rna'.
adata.obs['leiden_gcn_rna'] = adata.obs['leiden'].replace(
    {str(label): '8' for label in range(9, 16)}
)

In [None]:
# Neighbour graph on the 'gcn_onlyrecon' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="gcn_onlyrecon")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.07,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '9' through '15' into '8' and store the result as 'leiden_gcn_onlyrecon'.
adata.obs['leiden_gcn_onlyrecon'] = adata.obs['leiden'].replace(
    {str(label): '8' for label in range(9, 16)}
)

In [None]:
# Neighbour graph on the 'gcn_onlyrecon_concat' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="gcn_onlyrecon_concat")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.07,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '9' through '18' into '8' and store the result as 'leiden_gcn_onlyrecon_concat'.
adata.obs['leiden_gcn_onlyrecon_concat'] = adata.obs['leiden'].replace(
    {str(label): '8' for label in range(9, 19)}
)

In [None]:
# Neighbour graph on the 'gcn_onlyrecon_he' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="gcn_onlyrecon_he")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.07,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '9' through '14' into '8' and store the result as 'leiden_gcn_onlyrecon_he'.
adata.obs['leiden_gcn_onlyrecon_he'] = adata.obs['leiden'].replace(
    {str(label): '8' for label in range(9, 15)}
)

In [None]:
# Neighbour graph on the 'gcn_onlyrecon_rna' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="gcn_onlyrecon_rna")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.1,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '9' and '10' into '8' and store the result as 'leiden_gcn_onlyrecon_rna'.
adata.obs['leiden_gcn_onlyrecon_rna'] = adata.obs['leiden'].replace(
    {str(label): '8' for label in range(9, 11)}
)

In [None]:
# Neighbour graph on the 'banksy' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="banksy")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.3,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '9' and '10' into '8' and store the result as 'leiden_banksy'.
adata.obs['leiden_banksy'] = adata.obs['leiden'].replace(
    {str(label): '8' for label in range(9, 11)}
)

In [None]:
# Neighbour graph on the 'OmiCLIP_text' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="OmiCLIP_text")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.35,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge label '9' into '8' and store the result as 'leiden_OmiCLIP_text'.
adata.obs['leiden_OmiCLIP_text'] = adata.obs['leiden'].replace({"9": "8"})

In [None]:
# Neighbour graph on the 'OmiCLIP_image' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="OmiCLIP_image")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.14,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '9' through '11' into '8' and store the result as 'leiden_OmiCLIP_image'.
adata.obs['leiden_OmiCLIP_image'] = adata.obs['leiden'].replace(
    {str(label): '8' for label in range(9, 12)}
)

In [None]:
# Neighbour graph on the 'nichecompass' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="nichecompass")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.2,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '9' and '10' into '8' and store the result as 'leiden_nichecompass'.
adata.obs['leiden_nichecompass'] = adata.obs['leiden'].replace(
    {str(label): '8' for label in range(9, 11)}
)

In [None]:
# Neighbour graph on the 'nicheformer' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="nicheformer")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.4,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Store the relabeled clusters unchanged (no merging) as 'leiden_nicheformer'.
adata.obs['leiden_nicheformer'] = adata.obs['leiden']

In [None]:
# Neighbour graph on the 'scgptspatial' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="scgptspatial")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.2,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge label '9' into '8' and store the result as 'leiden_scgptspatial'.
adata.obs['leiden_scgptspatial'] = adata.obs['leiden'].replace({"9": "8"})

In [None]:
# Stash a copy of the current adata.X in the 'counts' layer before the next cell normalises X.
adata.layers["counts"] = adata.X.copy()

In [None]:
# Scanpy baseline: normalise each observation to a total of 1e4, then log1p-transform.
# NOTE(review): both calls rewrite adata.X, so re-running this cell transforms
# the data a second time; re-run only from a fresh adata.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# PCA with scanpy defaults on the transformed matrix.
sc.tl.pca(adata)

# Neighbour graph on the PCA representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep='X_pca')

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.5,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Store the relabeled clusters unchanged (no merging) as 'leiden_scanpy'.
adata.obs['leiden_scanpy'] = adata.obs['leiden']

In [None]:
# Persist the obs table, including all 'leiden_*' columns computed above.
adata.obs.to_csv("benchmark_ovarian_adata_obs_cl9.csv")

In [None]:
# Reload the saved obs table and swap it in as adata.obs.
adata_obs = pd.read_csv('benchmark_ovarian_adata_obs_cl9.csv', index_col=0)
# read_csv may parse the index as integers; cast it back to strings before
# assigning, as the commented re-read cell earlier in this section does.
adata_obs.index = adata_obs.index.astype(str)

adata.obs = adata_obs

In [None]:
# Display name -> adata.obs column holding that method's cluster labels.
clustering_keys = dict(
    [
        # SpatialFusion variants
        ("SpatialFusion", "leiden_gcn"),
        ("SpatialFusion (concat)", "leiden_gcn_concat"),
        ("SpatialFusion (H&E)", "leiden_gcn_he"),
        ("SpatialFusion (RNA)", "leiden_gcn_rna"),
        ("SpatialFusion (recon)", "leiden_gcn_onlyrecon"),
        ("SpatialFusion (recon concat)", "leiden_gcn_onlyrecon_concat"),
        ("SpatialFusion (recon H&E)", "leiden_gcn_onlyrecon_he"),
        ("SpatialFusion (recon RNA)", "leiden_gcn_onlyrecon_rna"),
        # Other methods
        ("NicheCompass", "leiden_nichecompass"),
        ("BANKSY", "leiden_banksy"),
        ("Nicheformer", "leiden_nicheformer"),
        ("scGPT-spatial", "leiden_scgptspatial"),
        ("OmiCLIP text", "leiden_OmiCLIP_text"),
        ("OmiCLIP image", "leiden_OmiCLIP_image"),
        ("Scanpy", "leiden_scanpy"),
    ]
)

# Score every clustering listed above.
results_df = compute_all_metrics(adata, clustering_keys)


In [None]:
import matplotlib

# 'none' keeps text as text (not outlined paths) in exported SVGs.
matplotlib.rcParams.update({'svg.fonttype': 'none'})
plot_benchmark_heatmap(
    results_df,
    title="OVCA Benchmark",
    savefig='../../../SpatialFusion/results/figures_Fig2/OVCA_cl9_benchmark.svg',
)

# Cl = 11

In [None]:
# Keep only the first 20 obs columns. Presumably these are the original
# metadata and this drops the clustering columns from the previous section —
# TODO confirm the column count against the loaded data.
adata.obs = adata.obs.iloc[:, :20]

In [None]:
# Neighbour graph on the 'gcn' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="gcn")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.1,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '11' through '17' into '10' and store the result as 'leiden_gcn'.
adata.obs['leiden_gcn'] = adata.obs['leiden'].replace(
    {str(label): '10' for label in range(11, 18)}
)

In [None]:
# Neighbour graph on the 'gcn_concat' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="gcn_concat")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.15,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '11' through '15' into '10' and store the result as 'leiden_gcn_concat'.
adata.obs['leiden_gcn_concat'] = adata.obs['leiden'].replace(
    {str(label): '10' for label in range(11, 16)}
)

In [None]:
# Neighbour graph on the 'gcn_he' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="gcn_he")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.1,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '11' through '16' into '10' and store the result as 'leiden_gcn_he'.
adata.obs['leiden_gcn_he'] = adata.obs['leiden'].replace(
    {str(label): '10' for label in range(11, 17)}
)

In [None]:
# Neighbour graph on the 'gcn_rna' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="gcn_rna")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.15,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '11' through '17' into '10' and store the result as 'leiden_gcn_rna'.
adata.obs['leiden_gcn_rna'] = adata.obs['leiden'].replace(
    {str(label): '10' for label in range(11, 18)}
)

In [None]:
# Neighbour graph on the 'gcn_onlyrecon' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="gcn_onlyrecon")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.1,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '11' through '19' into '10' and store the result as 'leiden_gcn_onlyrecon'.
adata.obs['leiden_gcn_onlyrecon'] = adata.obs['leiden'].replace(
    {str(label): '10' for label in range(11, 20)}
)

In [None]:
# Neighbour graph on the 'gcn_onlyrecon_concat' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="gcn_onlyrecon_concat")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.1,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '11' through '20' into '10' and store the result as 'leiden_gcn_onlyrecon_concat'.
adata.obs['leiden_gcn_onlyrecon_concat'] = adata.obs['leiden'].replace(
    {str(label): '10' for label in range(11, 21)}
)

In [None]:
# Neighbour graph on the 'gcn_onlyrecon_he' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="gcn_onlyrecon_he")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.1,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '11' through '16' into '10' and store the result as 'leiden_gcn_onlyrecon_he'.
adata.obs['leiden_gcn_onlyrecon_he'] = adata.obs['leiden'].replace(
    {str(label): '10' for label in range(11, 17)}
)

In [None]:
# Neighbour graph on the 'gcn_onlyrecon_rna' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="gcn_onlyrecon_rna")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.18,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '11' through '16' into '10' and store the result as 'leiden_gcn_onlyrecon_rna'.
adata.obs['leiden_gcn_onlyrecon_rna'] = adata.obs['leiden'].replace(
    {str(label): '10' for label in range(11, 17)}
)

In [None]:
# Neighbour graph on the 'banksy' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="banksy")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.4,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '11' through '14' into '10' and store the result as 'leiden_banksy'.
adata.obs['leiden_banksy'] = adata.obs['leiden'].replace(
    {str(label): '10' for label in range(11, 15)}
)

In [None]:
# Neighbour graph on the 'OmiCLIP_text' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="OmiCLIP_text")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.5,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Store the relabeled clusters unchanged (no merging) as 'leiden_OmiCLIP_text'.
adata.obs['leiden_OmiCLIP_text'] = adata.obs['leiden']

In [None]:
# Neighbour graph on the 'OmiCLIP_image' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="OmiCLIP_image")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.2,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '11' and '12' into '10' and store the result as 'leiden_OmiCLIP_image'.
adata.obs['leiden_OmiCLIP_image'] = adata.obs['leiden'].replace(
    {str(label): '10' for label in range(11, 13)}
)

In [None]:
# Neighbour graph on the 'nichecompass' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="nichecompass")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.2,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Store the relabeled clusters unchanged (no merging) as 'leiden_nichecompass'.
adata.obs['leiden_nichecompass'] = adata.obs['leiden']

In [None]:
# Neighbour graph on the 'nicheformer' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="nicheformer")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.5,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Store the relabeled clusters unchanged (no merging) as 'leiden_nicheformer'.
adata.obs['leiden_nicheformer'] = adata.obs['leiden']

In [None]:
# Neighbour graph on the 'scgptspatial' representation (input to the Leiden run that follows).
sc.pp.neighbors(adata, use_rep="scgptspatial")

In [None]:
# Leiden clustering of the current neighbour graph; labels land in adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.3,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank clusters by size, largest first.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Old label -> size rank (as a string), so '0' is the largest cluster.
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Apply the relabeling and keep the column categorical.
adata.obs['leiden'] = adata.obs['leiden'].map(mapping).astype('category')

# Sort the categories numerically ('0', '1', ..., '10', ...). reorder_categories
# returns a new Series rather than modifying in place, so assign the result back.
adata.obs['leiden'] = adata.obs['leiden'].cat.reorder_categories(
    sorted(adata.obs['leiden'].cat.categories, key=int)
)

print("Cluster relabeling done ✅")

# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Merge labels '11' and '12' into '10' and store the result as 'leiden_scgptspatial'.
adata.obs['leiden_scgptspatial'] = adata.obs['leiden'].replace(
    {str(label): '10' for label in range(11, 13)}
)

In [None]:
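# NOTE(review): normalisation is left disabled here, so the PCA in the next cell runs on
# whatever adata.X currently holds — presumably already normalised upstream; confirm.
#adata.layers['counts'] = adata.X.copy()

#sc.pp.normalize_total(adata, target_sum=1e4)
#sc.pp.log1p(adata)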

In [None]:
# Expression-only baseline: PCA, then a neighbour graph built on the PCA embedding.
sc.tl.pca(adata)
sc.pp.neighbors(adata, use_rep="X_pca")

In [None]:
# Leiden clustering on the current neighbour graph; the next cell reads the labels
# from adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.8,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank the current Leiden clusters by size so that label '0' is the largest cluster.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Map old label -> new label (its size rank, stored as a string).
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Relabel, then order the categories numerically. `cat.reorder_categories` returns a
# new Series instead of modifying in place, so the result must be assigned back.
relabeled = adata.obs['leiden'].map(mapping).astype('category')
adata.obs['leiden'] = relabeled.cat.reorder_categories(
    sorted(relabeled.cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Keep the relabelled PCA-based clusters under the 'leiden_scanpy' key.
adata.obs['leiden_scanpy'] = adata.obs['leiden']

In [None]:
# Persist all per-cell annotations (including every 'leiden_*' column) to CSV.
adata.obs.to_csv(path_or_buf='benchmark_ovarian_adata_obs_cl11.csv')

In [None]:
# Display name -> adata.obs column holding that method's cluster labels.
# Built from ordered pairs; the resulting dict keeps this insertion order.
clustering_keys = dict([
    ('SpatialFusion', 'leiden_gcn'),
    ('SpatialFusion (concat)', 'leiden_gcn_concat'),
    ('SpatialFusion (H&E)', 'leiden_gcn_he'),
    ('SpatialFusion (RNA)', 'leiden_gcn_rna'),
    ('SpatialFusion (recon)', 'leiden_gcn_onlyrecon'),
    ('SpatialFusion (recon concat)', 'leiden_gcn_onlyrecon_concat'),
    ('SpatialFusion (recon H&E)', 'leiden_gcn_onlyrecon_he'),
    ('SpatialFusion (recon RNA)', 'leiden_gcn_onlyrecon_rna'),
    ('NicheCompass', 'leiden_nichecompass'),
    ('BANKSY', 'leiden_banksy'),
    ('Nicheformer', 'leiden_nicheformer'),
    ('scGPT-spatial', 'leiden_scgptspatial'),
    ('OmiCLIP text', 'leiden_OmiCLIP_text'),
    ('OmiCLIP image', 'leiden_OmiCLIP_image'),
    ('Scanpy', 'leiden_scanpy'),
])

# Score every clustering listed above; the result feeds the benchmark heatmap.
results_df = compute_all_metrics(adata, clustering_keys)


In [None]:
import matplotlib

# 'none' writes text as real text elements in the SVG rather than converting it to paths.
matplotlib.rcParams.update({'svg.fonttype': 'none'})

plot_benchmark_heatmap(
    results_df,
    title="OVCA Benchmark",
    savefig='../../../SpatialFusion/results/figures_Fig2/OVCA_cl11_benchmark.svg',
)

## CL = ground truth (number of clusters matched to the ground-truth label count)

In [None]:
# Keep only the first 20 obs columns, discarding everything after them.
# NOTE(review): positional cut-off — presumably this drops the 'leiden*' columns of the
# previous benchmark; confirm that exactly 20 metadata columns precede them.
adata.obs = adata.obs.iloc[:, :20]

In [None]:
# Rebuild the neighbour graph from the 'gcn' representation.
sc.pp.neighbors(adata, use_rep="gcn")

In [None]:
# Leiden clustering on the current neighbour graph; the next cell reads the labels
# from adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.04,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank the current Leiden clusters by size so that label '0' is the largest cluster.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Map old label -> new label (its size rank, stored as a string).
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Relabel, then order the categories numerically. `cat.reorder_categories` returns a
# new Series instead of modifying in place, so the result must be assigned back.
relabeled = adata.obs['leiden'].map(mapping).astype('category')
adata.obs['leiden'] = relabeled.cat.reorder_categories(
    sorted(relabeled.cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Fold the smallest clusters ('6', '7', '8') into cluster '5' and keep the result
# under 'leiden_gcn'.
# Series.replace on a categorical column is deprecated since pandas 2.2 for merging
# categories, so merge on plain strings and rebuild a numerically ordered categorical.
merged_labels = adata.obs['leiden'].astype(str).replace({'6': '5', '7': '5', '8': '5'})
adata.obs['leiden_gcn'] = merged_labels.astype(
    pd.CategoricalDtype(sorted(merged_labels.unique(), key=int))
)

In [None]:
# Rebuild the neighbour graph from the 'gcn_concat' representation.
sc.pp.neighbors(adata, use_rep="gcn_concat")

In [None]:
# Leiden clustering on the current neighbour graph; the next cell reads the labels
# from adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.05,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank the current Leiden clusters by size so that label '0' is the largest cluster.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Map old label -> new label (its size rank, stored as a string).
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Relabel, then order the categories numerically. `cat.reorder_categories` returns a
# new Series instead of modifying in place, so the result must be assigned back.
relabeled = adata.obs['leiden'].map(mapping).astype('category')
adata.obs['leiden'] = relabeled.cat.reorder_categories(
    sorted(relabeled.cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Fold the smallest clusters ('6', '7') into cluster '5' and keep the result
# under 'leiden_gcn_concat'.
# Series.replace on a categorical column is deprecated since pandas 2.2 for merging
# categories, so merge on plain strings and rebuild a numerically ordered categorical.
merged_labels = adata.obs['leiden'].astype(str).replace({'6': '5', '7': '5'})
adata.obs['leiden_gcn_concat'] = merged_labels.astype(
    pd.CategoricalDtype(sorted(merged_labels.unique(), key=int))
)

In [None]:
# Rebuild the neighbour graph from the 'gcn_he' representation.
sc.pp.neighbors(adata, use_rep="gcn_he")

In [None]:
# Leiden clustering on the current neighbour graph; the next cell reads the labels
# from adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.05,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank the current Leiden clusters by size so that label '0' is the largest cluster.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Map old label -> new label (its size rank, stored as a string).
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Relabel, then order the categories numerically. `cat.reorder_categories` returns a
# new Series instead of modifying in place, so the result must be assigned back.
relabeled = adata.obs['leiden'].map(mapping).astype('category')
adata.obs['leiden'] = relabeled.cat.reorder_categories(
    sorted(relabeled.cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Fold the smallest clusters ('6', '7', '8') into cluster '5' and keep the result
# under 'leiden_gcn_he'.
# Series.replace on a categorical column is deprecated since pandas 2.2 for merging
# categories, so merge on plain strings and rebuild a numerically ordered categorical.
merged_labels = adata.obs['leiden'].astype(str).replace({'6': '5', '7': '5', '8': '5'})
adata.obs['leiden_gcn_he'] = merged_labels.astype(
    pd.CategoricalDtype(sorted(merged_labels.unique(), key=int))
)

In [None]:
# Rebuild the neighbour graph from the 'gcn_rna' representation.
sc.pp.neighbors(adata, use_rep="gcn_rna")

In [None]:
# Leiden clustering on the current neighbour graph; the next cell reads the labels
# from adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.05,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank the current Leiden clusters by size so that label '0' is the largest cluster.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Map old label -> new label (its size rank, stored as a string).
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Relabel, then order the categories numerically. `cat.reorder_categories` returns a
# new Series instead of modifying in place, so the result must be assigned back.
relabeled = adata.obs['leiden'].map(mapping).astype('category')
adata.obs['leiden'] = relabeled.cat.reorder_categories(
    sorted(relabeled.cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Fold the smallest clusters ('6', '7', '8') into cluster '5' and keep the result
# under 'leiden_gcn_rna'.
# Series.replace on a categorical column is deprecated since pandas 2.2 for merging
# categories, so merge on plain strings and rebuild a numerically ordered categorical.
merged_labels = adata.obs['leiden'].astype(str).replace({'6': '5', '7': '5', '8': '5'})
adata.obs['leiden_gcn_rna'] = merged_labels.astype(
    pd.CategoricalDtype(sorted(merged_labels.unique(), key=int))
)

In [None]:
# Rebuild the neighbour graph from the 'gcn_onlyrecon' representation.
sc.pp.neighbors(adata, use_rep="gcn_onlyrecon")

In [None]:
# Leiden clustering on the current neighbour graph; the next cell reads the labels
# from adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.03,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank the current Leiden clusters by size so that label '0' is the largest cluster.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Map old label -> new label (its size rank, stored as a string).
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Relabel, then order the categories numerically. `cat.reorder_categories` returns a
# new Series instead of modifying in place, so the result must be assigned back.
relabeled = adata.obs['leiden'].map(mapping).astype('category')
adata.obs['leiden'] = relabeled.cat.reorder_categories(
    sorted(relabeled.cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Fold the smallest clusters ('6', '7') into cluster '5' and keep the result
# under 'leiden_gcn_onlyrecon'.
# Series.replace on a categorical column is deprecated since pandas 2.2 for merging
# categories, so merge on plain strings and rebuild a numerically ordered categorical.
merged_labels = adata.obs['leiden'].astype(str).replace({'6': '5', '7': '5'})
adata.obs['leiden_gcn_onlyrecon'] = merged_labels.astype(
    pd.CategoricalDtype(sorted(merged_labels.unique(), key=int))
)

In [None]:
# Rebuild the neighbour graph from the 'gcn_onlyrecon_concat' representation.
sc.pp.neighbors(adata, use_rep="gcn_onlyrecon_concat")

In [None]:
# Leiden clustering on the current neighbour graph; the next cell reads the labels
# from adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.04,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank the current Leiden clusters by size so that label '0' is the largest cluster.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Map old label -> new label (its size rank, stored as a string).
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Relabel, then order the categories numerically. `cat.reorder_categories` returns a
# new Series instead of modifying in place, so the result must be assigned back.
relabeled = adata.obs['leiden'].map(mapping).astype('category')
adata.obs['leiden'] = relabeled.cat.reorder_categories(
    sorted(relabeled.cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Fold the smallest clusters ('6', '7') into cluster '5' and keep the result
# under 'leiden_gcn_onlyrecon_concat'.
# Series.replace on a categorical column is deprecated since pandas 2.2 for merging
# categories, so merge on plain strings and rebuild a numerically ordered categorical.
merged_labels = adata.obs['leiden'].astype(str).replace({'6': '5', '7': '5'})
adata.obs['leiden_gcn_onlyrecon_concat'] = merged_labels.astype(
    pd.CategoricalDtype(sorted(merged_labels.unique(), key=int))
)

In [None]:
# Rebuild the neighbour graph from the 'gcn_onlyrecon_he' representation.
sc.pp.neighbors(adata, use_rep="gcn_onlyrecon_he")

In [None]:
# Leiden clustering on the current neighbour graph; the next cell reads the labels
# from adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.05,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank the current Leiden clusters by size so that label '0' is the largest cluster.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Map old label -> new label (its size rank, stored as a string).
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Relabel, then order the categories numerically. `cat.reorder_categories` returns a
# new Series instead of modifying in place, so the result must be assigned back.
relabeled = adata.obs['leiden'].map(mapping).astype('category')
adata.obs['leiden'] = relabeled.cat.reorder_categories(
    sorted(relabeled.cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Fold the smallest clusters ('6' to '10') into cluster '5' and keep the result
# under 'leiden_gcn_onlyrecon_he'.
# Series.replace on a categorical column is deprecated since pandas 2.2 for merging
# categories, so merge on plain strings and rebuild a numerically ordered categorical.
merged_labels = adata.obs['leiden'].astype(str).replace(
    {'6': '5', '7': '5', '8': '5', '9': '5', '10': '5'}
)
adata.obs['leiden_gcn_onlyrecon_he'] = merged_labels.astype(
    pd.CategoricalDtype(sorted(merged_labels.unique(), key=int))
)

In [None]:
# Rebuild the neighbour graph from the 'gcn_onlyrecon_rna' representation.
sc.pp.neighbors(adata, use_rep="gcn_onlyrecon_rna")

In [None]:
# Leiden clustering on the current neighbour graph; the next cell reads the labels
# from adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.05,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank the current Leiden clusters by size so that label '0' is the largest cluster.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Map old label -> new label (its size rank, stored as a string).
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Relabel, then order the categories numerically. `cat.reorder_categories` returns a
# new Series instead of modifying in place, so the result must be assigned back.
relabeled = adata.obs['leiden'].map(mapping).astype('category')
adata.obs['leiden'] = relabeled.cat.reorder_categories(
    sorted(relabeled.cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Fold the smallest clusters ('6', '7', '8') into cluster '5' and keep the result
# under 'leiden_gcn_onlyrecon_rna'.
# Series.replace on a categorical column is deprecated since pandas 2.2 for merging
# categories, so merge on plain strings and rebuild a numerically ordered categorical.
merged_labels = adata.obs['leiden'].astype(str).replace({'6': '5', '7': '5', '8': '5'})
adata.obs['leiden_gcn_onlyrecon_rna'] = merged_labels.astype(
    pd.CategoricalDtype(sorted(merged_labels.unique(), key=int))
)

In [None]:
# Rebuild the neighbour graph from the 'banksy' representation.
sc.pp.neighbors(adata, use_rep="banksy")

In [None]:
# Leiden clustering on the current neighbour graph; the next cell reads the labels
# from adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.1,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank the current Leiden clusters by size so that label '0' is the largest cluster.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Map old label -> new label (its size rank, stored as a string).
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Relabel, then order the categories numerically. `cat.reorder_categories` returns a
# new Series instead of modifying in place, so the result must be assigned back.
relabeled = adata.obs['leiden'].map(mapping).astype('category')
adata.obs['leiden'] = relabeled.cat.reorder_categories(
    sorted(relabeled.cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Fold the smallest clusters ('6', '7') into cluster '5' and keep the result
# under 'leiden_banksy'.
# Series.replace on a categorical column is deprecated since pandas 2.2 for merging
# categories, so merge on plain strings and rebuild a numerically ordered categorical.
merged_labels = adata.obs['leiden'].astype(str).replace({'6': '5', '7': '5'})
adata.obs['leiden_banksy'] = merged_labels.astype(
    pd.CategoricalDtype(sorted(merged_labels.unique(), key=int))
)

In [None]:
# Rebuild the neighbour graph from the 'OmiCLIP_text' representation.
sc.pp.neighbors(adata, use_rep="OmiCLIP_text")

In [None]:
# Leiden clustering on the current neighbour graph; the next cell reads the labels
# from adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.23,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank the current Leiden clusters by size so that label '0' is the largest cluster.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Map old label -> new label (its size rank, stored as a string).
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Relabel, then order the categories numerically. `cat.reorder_categories` returns a
# new Series instead of modifying in place, so the result must be assigned back.
relabeled = adata.obs['leiden'].map(mapping).astype('category')
adata.obs['leiden'] = relabeled.cat.reorder_categories(
    sorted(relabeled.cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Fold the smallest cluster ('6') into cluster '5' and keep the result
# under 'leiden_OmiCLIP_text'.
# Series.replace on a categorical column is deprecated since pandas 2.2 for merging
# categories, so merge on plain strings and rebuild a numerically ordered categorical.
merged_labels = adata.obs['leiden'].astype(str).replace({'6': '5'})
adata.obs['leiden_OmiCLIP_text'] = merged_labels.astype(
    pd.CategoricalDtype(sorted(merged_labels.unique(), key=int))
)

In [None]:
# Rebuild the neighbour graph from the 'OmiCLIP_image' representation.
sc.pp.neighbors(adata, use_rep="OmiCLIP_image")

In [None]:
# Leiden clustering on the current neighbour graph; the next cell reads the labels
# from adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.05,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank the current Leiden clusters by size so that label '0' is the largest cluster.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Map old label -> new label (its size rank, stored as a string).
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Relabel, then order the categories numerically. `cat.reorder_categories` returns a
# new Series instead of modifying in place, so the result must be assigned back.
relabeled = adata.obs['leiden'].map(mapping).astype('category')
adata.obs['leiden'] = relabeled.cat.reorder_categories(
    sorted(relabeled.cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Fold the smallest cluster ('6') into cluster '5' and keep the result
# under 'leiden_OmiCLIP_image'.
# Series.replace on a categorical column is deprecated since pandas 2.2 for merging
# categories, so merge on plain strings and rebuild a numerically ordered categorical.
merged_labels = adata.obs['leiden'].astype(str).replace({'6': '5'})
adata.obs['leiden_OmiCLIP_image'] = merged_labels.astype(
    pd.CategoricalDtype(sorted(merged_labels.unique(), key=int))
)

In [None]:
# Rebuild the neighbour graph from the 'nichecompass' representation.
sc.pp.neighbors(adata, use_rep="nichecompass")

In [None]:
# Leiden clustering on the current neighbour graph; the next cell reads the labels
# from adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.1,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank the current Leiden clusters by size so that label '0' is the largest cluster.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Map old label -> new label (its size rank, stored as a string).
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Relabel, then order the categories numerically. `cat.reorder_categories` returns a
# new Series instead of modifying in place, so the result must be assigned back.
relabeled = adata.obs['leiden'].map(mapping).astype('category')
adata.obs['leiden'] = relabeled.cat.reorder_categories(
    sorted(relabeled.cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Fold the smallest cluster ('6') into cluster '5' and keep the result
# under 'leiden_nichecompass'.
# Series.replace on a categorical column is deprecated since pandas 2.2 for merging
# categories, so merge on plain strings and rebuild a numerically ordered categorical.
merged_labels = adata.obs['leiden'].astype(str).replace({'6': '5'})
adata.obs['leiden_nichecompass'] = merged_labels.astype(
    pd.CategoricalDtype(sorted(merged_labels.unique(), key=int))
)

In [None]:
# Rebuild the neighbour graph from the 'nicheformer' representation.
sc.pp.neighbors(adata, use_rep="nicheformer")

In [None]:
# Leiden clustering on the current neighbour graph; the next cell reads the labels
# from adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.2,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank the current Leiden clusters by size so that label '0' is the largest cluster.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Map old label -> new label (its size rank, stored as a string).
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Relabel, then order the categories numerically. `cat.reorder_categories` returns a
# new Series instead of modifying in place, so the result must be assigned back.
relabeled = adata.obs['leiden'].map(mapping).astype('category')
adata.obs['leiden'] = relabeled.cat.reorder_categories(
    sorted(relabeled.cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Keep the relabelled clusters under a Nicheformer-specific key before the next
# clustering run overwrites adata.obs['leiden'] (no clusters are merged here).
adata.obs['leiden_nicheformer'] = adata.obs['leiden']

In [None]:
# Rebuild the neighbour graph from the 'scgptspatial' representation.
sc.pp.neighbors(adata, use_rep="scgptspatial")

In [None]:
# Leiden clustering on the current neighbour graph; the next cell reads the labels
# from adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.1,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank the current Leiden clusters by size so that label '0' is the largest cluster.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Map old label -> new label (its size rank, stored as a string).
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Relabel, then order the categories numerically. `cat.reorder_categories` returns a
# new Series instead of modifying in place, so the result must be assigned back.
relabeled = adata.obs['leiden'].map(mapping).astype('category')
adata.obs['leiden'] = relabeled.cat.reorder_categories(
    sorted(relabeled.cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Fold the smallest cluster ('6') into cluster '5' and keep the result
# under 'leiden_scgptspatial'.
# Series.replace on a categorical column is deprecated since pandas 2.2 for merging
# categories, so merge on plain strings and rebuild a numerically ordered categorical.
merged_labels = adata.obs['leiden'].astype(str).replace({'6': '5'})
adata.obs['leiden_scgptspatial'] = merged_labels.astype(
    pd.CategoricalDtype(sorted(merged_labels.unique(), key=int))
)

In [None]:
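# NOTE(review): normalisation is left disabled here, so the PCA in the next cell runs on
# whatever adata.X currently holds — presumably already normalised upstream; confirm.
#adata.layers['counts'] = adata.X.copy()

#sc.pp.normalize_total(adata, target_sum=1e4)
#sc.pp.log1p(adata)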

In [None]:
# Expression-only baseline: PCA, then a neighbour graph built on the PCA embedding.
sc.tl.pca(adata)
sc.pp.neighbors(adata, use_rep="X_pca")

In [None]:
# Leiden clustering on the current neighbour graph; the next cell reads the labels
# from adata.obs['leiden'].
sc.tl.leiden(
    adata,
    resolution=0.25,
    flavor="igraph",
    n_iterations=2,
)

In [None]:
# Rank the current Leiden clusters by size so that label '0' is the largest cluster.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)

# Map old label -> new label (its size rank, stored as a string).
mapping = {old: str(new) for new, old in enumerate(counts.index)}

# Relabel, then order the categories numerically. `cat.reorder_categories` returns a
# new Series instead of modifying in place, so the result must be assigned back.
relabeled = adata.obs['leiden'].map(mapping).astype('category')
adata.obs['leiden'] = relabeled.cat.reorder_categories(
    sorted(relabeled.cat.categories, key=int)
)

print("Cluster relabeling done ✅")
# Show the cluster sizes under the new labels.
counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
display(counts)

In [None]:
# Keep the relabelled PCA-based clusters under the 'leiden_scanpy' key.
adata.obs['leiden_scanpy'] = adata.obs['leiden']

In [None]:
# Persist all per-cell annotations (including every 'leiden_*' column) to CSV.
adata.obs.to_csv(path_or_buf='benchmark_ovarian_adata_obs_cl6.csv')

In [None]:
# Reload the saved annotations so the benchmark can be rerun without re-clustering.
adata_obs = pd.read_csv('benchmark_ovarian_adata_obs_cl6.csv', index_col=0)

# CSV does not preserve dtypes: the index and the cluster labels can come back as
# integers. Restore a string index and categorical string labels so the metrics see
# the same dtypes as when the clusterings are computed in memory.
adata_obs.index = adata_obs.index.astype(str)
for col in adata_obs.columns:
    if col.startswith('leiden'):
        adata_obs[col] = adata_obs[col].astype(str).astype('category')

adata.obs = adata_obs

In [None]:
# Display name -> adata.obs column holding that method's cluster labels.
# Built from ordered pairs; the resulting dict keeps this insertion order.
clustering_keys = dict([
    ('SpatialFusion', 'leiden_gcn'),
    ('SpatialFusion (concat)', 'leiden_gcn_concat'),
    ('SpatialFusion (H&E)', 'leiden_gcn_he'),
    ('SpatialFusion (RNA)', 'leiden_gcn_rna'),
    ('SpatialFusion (recon)', 'leiden_gcn_onlyrecon'),
    ('SpatialFusion (recon concat)', 'leiden_gcn_onlyrecon_concat'),
    ('SpatialFusion (recon H&E)', 'leiden_gcn_onlyrecon_he'),
    ('SpatialFusion (recon RNA)', 'leiden_gcn_onlyrecon_rna'),
    ('NicheCompass', 'leiden_nichecompass'),
    ('BANKSY', 'leiden_banksy'),
    ('Nicheformer', 'leiden_nicheformer'),
    ('scGPT-spatial', 'leiden_scgptspatial'),
    ('OmiCLIP text', 'leiden_OmiCLIP_text'),
    ('OmiCLIP image', 'leiden_OmiCLIP_image'),
    ('Scanpy', 'leiden_scanpy'),
])

# Score every clustering listed above; the result feeds the benchmark heatmap.
results_df = compute_all_metrics(adata, clustering_keys)


In [None]:
import matplotlib

# 'none' writes text as real text elements in the SVG rather than converting it to paths.
matplotlib.rcParams.update({'svg.fonttype': 'none'})

plot_benchmark_heatmap(
    results_df,
    title="OVCA Benchmark",
    savefig='../../../SpatialFusion/results/figures_Fig2/OVCA_cl6_benchmark.svg',
)