In [None]:
import scanpy as sc
import numpy as np
import pandas as pd

In [None]:
from spatialfusion.embed.embed import load_gcn, gcn_embeddings_from_joint

In [None]:
import pathlib as pl

In [None]:
from tqdm.notebook import tqdm

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA

# Full GCN

Load the configuration that was used for the full-GCN run.

In [None]:
from omegaconf import OmegaConf
# Configuration of the full-GCN run; supplies dataset.datapath and dataset.test_samples below.
cfg = OmegaConf.load(
    pl.Path("../../results/checkpoint_dir_gcn/checkpoint_dir_gcn/gcn_20251022-170459_5a4f4f64")
    / "config_5a4f4f64.yaml"
)

In [None]:
# Root folder holding one sub-directory per sample.
datapath = pl.Path(cfg["dataset"]["datapath"])

In [None]:
# One AnnData per test sample, read from <datapath>/<sample>/adata.h5ad.
adata_by_sample = {
    sample: sc.read_h5ad(datapath / sample / 'adata.h5ad')
    for sample in tqdm(cfg.dataset.test_samples)
}

In [None]:
# AE joint embeddings fed to the GCN (the file name indicates the test split).
z_joint_df = pd.read_parquet(
    pl.Path('../../results/embeddings_ae/full-AE-output-model-3085dad0') / 'z_joint_test.parquet'
)

In [None]:
# Directory containing the packaged full-GCN checkpoint.
gcn_model_dir = pl.Path('../../spatialfusion/data') / 'checkpoint_dir_gcn'

In [None]:
# Embed with the full GCN checkpoint.
# z_joint_df holds the AE joint embeddings concatenated across samples
# (index = cell ids from all samples); adata_by_sample maps sample names to their
# AnnData, and its keys must match the sample names wanted in the output metadata.
gcn_model = load_gcn(
    gcn_model_dir / "spatialfusion-full-gcn.pt",
    in_dim=z_joint_df.shape[1],
    device="cuda",
)

gcn_emb_df = gcn_embeddings_from_joint(
    gcn_model=gcn_model,
    z_joint=z_joint_df,               # one joint embedding covering all cells
    adata_by_sample=adata_by_sample,  # per-sample AnnData objects
    base_path=".",                    # path for metadata lookups/saves
    device="cuda",
    spatial_key="spatial_he",         # alternative: "spatial_px"
    celltype_key="major_celltype",
    k=30,
)


In [None]:
# to_parquet does not create missing directories; make sure the output folder exists
# so the notebook also runs from a clean checkout.
pl.Path('../../results/embeddings_gcn/spatialfusion-full-gcn').mkdir(parents=True, exist_ok=True)
gcn_emb_df.to_parquet('../../results/embeddings_gcn/spatialfusion-full-gcn/SpatialFusion.parquet')

# GCN without pathway

In [None]:
from omegaconf import OmegaConf
# Configuration of the no-pathway GCN run; supplies dataset.datapath and dataset.test_samples.
cfg = OmegaConf.load(
    pl.Path("../../results/checkpoint_dir_gcn/checkpoint_dir_gcn/gcn_20251022-170720_4e2cecfe")
    / "config_4e2cecfe.yaml"
)

In [None]:
# Root folder holding one sub-directory per sample.
datapath = pl.Path(cfg["dataset"]["datapath"])

In [None]:
# One AnnData per test sample, read from <datapath>/<sample>/adata.h5ad.
adata_by_sample = {
    sample: sc.read_h5ad(datapath / sample / 'adata.h5ad')
    for sample in tqdm(cfg.dataset.test_samples)
}

In [None]:
# Same AE joint embeddings as for the full GCN (the file name indicates the test split).
z_joint_df = pd.read_parquet(
    pl.Path('../../results/embeddings_ae/full-AE-output-model-3085dad0') / 'z_joint_test.parquet'
)

In [None]:
# Run directory of the no-pathway GCN (holds model.pt and its config).
gcn_model_dir = pl.Path('../../results/checkpoint_dir_gcn/checkpoint_dir_gcn') / 'gcn_20251022-170720_4e2cecfe'

In [None]:
# Embed with the no-pathway GCN checkpoint, using the same inputs as the full GCN.
# z_joint_df holds the AE joint embeddings concatenated across samples
# (index = cell ids from all samples); adata_by_sample maps sample names to their
# AnnData, and its keys must match the sample names wanted in the output metadata.
gcn_model = load_gcn(
    gcn_model_dir / "model.pt",
    in_dim=z_joint_df.shape[1],
    device="cuda",
)

gcn_emb_nopathway_df = gcn_embeddings_from_joint(
    gcn_model=gcn_model,
    z_joint=z_joint_df,               # one joint embedding covering all cells
    adata_by_sample=adata_by_sample,  # per-sample AnnData objects
    base_path=".",                    # path for metadata lookups/saves
    device="cuda",
    spatial_key="spatial_he",         # alternative: "spatial_px"
    celltype_key="major_celltype",
    k=30,
)


In [None]:
# Inspect the no-pathway GCN embeddings before they are saved in the next cell.
gcn_emb_nopathway_df

In [None]:
# to_parquet does not create missing directories; make sure the output folder exists
# so the notebook also runs from a clean checkout.
pl.Path('../../results/embeddings_gcn/spatialfusion-nopathway-gcn').mkdir(parents=True, exist_ok=True)
gcn_emb_nopathway_df.to_parquet('../../results/embeddings_gcn/spatialfusion-nopathway-gcn/SpatialFusion.parquet')

# Compare performance

## Helper functions

In [None]:
def run_pca(emb_df, n_components=5):
    """
    Fit a PCA on the embedding columns of `emb_df` and append the component scores.

    'sample_id' is always excluded from the features (and must be present, or
    DataFrame.drop raises). Known metadata/annotation columns (cell ids, coordinates,
    niche and cell-type labels, and a few named signaling columns) are excluded
    only when they exist. Every other column is passed to PCA as a feature.

    Returns `emb_df` with a fresh RangeIndex and extra columns PC1..PC{n_components}.
    """
    optional_meta = [
        'cell_id', 'X_coord', 'Y_coord', 'CNiche', 'TNiche',
        'cellsubtype', 'celltype', 'cellsubtypes', 'celltypes',
        'CCL', 'CXCL', 'PD-L1', 'CD86', 'EGF', 'CEACAM', 'VEGF',
    ]
    to_drop = ['sample_id'] + [name for name in optional_meta if name in emb_df.columns]

    scores = PCA(n_components=n_components).fit_transform(emb_df.drop(columns=to_drop))
    score_names = [f"PC{j}" for j in range(1, n_components + 1)]
    return pd.concat(
        [emb_df.reset_index(drop=True), pd.DataFrame(scores, columns=score_names)],
        axis=1,
    )


def standardize_pathways(df: pd.DataFrame, method: str = "robust_z", eps: float = 1e-6, tol: float = 1e-3) -> pd.DataFrame:
    """
    Column-wise standardization of pathway scores.

    - 'robust_z': (x - median) / (IQR + eps)   (less sensitive to outliers / skew)
    - 'z':        (x - mean) / (std + eps)

    Columns with zero spread (IQR or std of exactly 0) become all 0. Columns whose
    values are all nearly zero (max |x| < tol) are set to 0. NaNs and infs are
    replaced by 0. The result is float32 with the same index and columns as `df`;
    the input is not modified.

    Raises
    ------
    ValueError
        If `method` is neither 'robust_z' nor 'z'. Previously any other value
        silently fell back to 'robust_z'.
    """
    if method not in ("robust_z", "z"):
        raise ValueError(f"Unknown standardization method {method!r}; expected 'robust_z' or 'z'.")

    # Detect "almost-zero" (uninformative) columns on the raw scores.
    all_near_zero = df.abs().max(axis=0) < tol

    if method == "z":
        center = df.mean(axis=0)
        spread = df.std(axis=0)
    else:  # robust z-score
        center = df.median(axis=0)
        spread = df.quantile(0.75, axis=0) - df.quantile(0.25, axis=0)

    # Zero spread -> NaN denominator -> NaN column, which the cleanup below maps to 0.
    out = (df - center) / (spread.replace(0, np.nan) + eps)

    # Force "uninformative" pathways to 0s after scaling.
    out.loc[:, all_near_zero] = 0.0

    # Clean up numerical edge cases.
    return out.replace([np.inf, -np.inf], np.nan).fillna(0.0).astype(np.float32)

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import math

def plot_pca_grid(pc_df, color_by_list, max_samples=None, group_by='patient_id', cmap="vlag", ncols=2, savefig=None):
    """
    Plot PC1-vs-PC2 scatterplots, one panel per variable, on a grid with `ncols` columns.
    Styled for publication-quality figures (e.g., Nature Genetics).

    Parameters
    ----------
    pc_df : pd.DataFrame
        Must contain columns 'PC1', 'PC2', and each variable in color_by_list.
    color_by_list : iterable of str
        Columns to color points by. Numeric (non-boolean) columns get a continuous
        colormap and colorbar. All others are drawn as categorical hues without a legend.
    max_samples : int, optional
        Max number of points kept per group (defined by group_by). Ignored when
        None or when `group_by` is not a column of `pc_df`.
    group_by : str
        Column to balance subsampling across. When subsampling, rows with a missing
        group label are dropped.
    cmap : str
        Colormap for continuous variables (default 'vlag').
    ncols : int
        Number of panel columns in the grid.
    savefig : str or path-like, optional
        If given, the figure is also saved there (dpi=200, tight bounding box).

    Raises
    ------
    ValueError
        If `color_by_list` is empty or names a column missing from `pc_df`.
    """
    color_by_list = list(color_by_list)
    n = len(color_by_list)
    # An empty list would otherwise fail inside plt.subplots(nrows=0) or leave the
    # loop variable unbound.
    if n == 0:
        raise ValueError("color_by_list must name at least one column to plot.")
    for color_by in color_by_list:
        if color_by not in pc_df.columns:
            raise ValueError(f"Column '{color_by}' not found in dataframe.")

    # --- Shuffle rows (fixed seed) so draw order does not follow the input order ---
    df = pc_df.sample(frac=1, random_state=42)

    # --- Subsample evenly across groups ---
    # In a shuffled frame, the first `max_samples` rows of each group are a uniform
    # random sample of that group. Unlike groupby().apply(..., include_groups=False),
    # this keeps the `group_by` column, so it can also appear in color_by_list.
    if max_samples is not None and group_by in df.columns:
        rank_in_group = df.groupby(group_by).cumcount().to_numpy()
        keep = df[group_by].notna().to_numpy() & (rank_in_group < max_samples)
        df = df[keep]
    df = df.reset_index(drop=True)

    # --- Figure setup ---
    nrows = math.ceil(n / ncols)

    sns.set_context("talk", font_scale=1.2)
    sns.set_style("white")

    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(5.5 * ncols, 5 * nrows))
    axes = np.array(axes).reshape(-1)

    for ax, color_by in zip(axes, color_by_list):
        values = df[color_by]

        # --- Continuous vs categorical ---
        # The pandas dtype check also accepts extension dtypes (e.g. category), on which
        # np.issubdtype raises. Booleans are treated as categorical.
        is_continuous = (pd.api.types.is_numeric_dtype(values)
                         and not pd.api.types.is_bool_dtype(values))
        if is_continuous:
            norm = plt.Normalize(values.min(), values.max())
            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
            sm.set_array([])

            ax.scatter(
                df["PC1"], df["PC2"],
                c=values, cmap=cmap, norm=norm,
                s=15, alpha=0.8, edgecolor='none'
            )

            cbar = plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
            cbar.ax.tick_params(labelsize=25)
            cbar.set_label(color_by, fontsize=25)

        else:
            sns.scatterplot(
                data=df,
                x="PC1", y="PC2",
                hue=color_by,
                s=15, alpha=0.8,
                ax=ax, linewidth=0, legend=False
            )

        # --- Titles and axes ---
        ax.set_title(f"{color_by}", fontsize=30,)
        ax.set_xlabel("PC1", fontsize=25)
        ax.set_ylabel("PC2", fontsize=25)

        # --- Remove ticks but keep axis labels ---
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])

        # --- Remove spines (Nature style) ---
        for spine in ["top", "right", "left", "bottom"]:
            ax.spines[spine].set_visible(False)

        # --- Clean layout ---
        ax.grid(False)
        ax.tick_params(axis='both', which='major', labelsize=9)

    # Hide unused axes (the grid can have more cells than variables).
    for ax in axes[n:]:
        ax.set_visible(False)

    plt.tight_layout()
    if savefig is not None:
        fig.savefig(savefig, dpi=200, bbox_inches='tight')
    plt.show()


## Analysis

Load the two saved embeddings and the per-sample pathway activation scores to compare.

In [None]:
# Reload the saved full-GCN embeddings.
gcn_emb_df = pd.read_parquet(
    pl.Path('../../results/embeddings_gcn/spatialfusion-full-gcn') / 'SpatialFusion.parquet'
)

In [None]:
# Reload the saved no-pathway GCN embeddings.
gcn_emb_nopathway_df = pd.read_parquet(
    pl.Path('../../results/embeddings_gcn/spatialfusion-nopathway-gcn') / 'SpatialFusion.parquet'
)

In [None]:
from omegaconf import OmegaConf
# Configuration of the no-pathway GCN run (source of datapath and test_samples below).
# Relative path into ../../results, like the other config loads in this notebook.
# A leading "/" would make the path absolute: "/../.." collapses to "/", so the file
# would be looked up at the filesystem root.
cfg = OmegaConf.load("../../results/checkpoint_dir_gcn/checkpoint_dir_gcn/gcn_20251022-170720_4e2cecfe/config_4e2cecfe.yaml")

In [None]:
# Root folder holding one sub-directory per sample (pathway activations live there).
datapath = pl.Path(cfg["dataset"]["datapath"])

In [None]:
# Per-sample pathway activations, each standardized within its own sample, then stacked.
all_pw = pd.concat([
    standardize_pathways(pd.read_parquet(datapath / patient / 'pathway_activation.parquet'))
    for patient in cfg['dataset']['test_samples']
])

In [None]:
# PCA (5 components) of the full-GCN embeddings.
pc_df = run_pca(emb_df=gcn_emb_df, n_components=5)

In [None]:
# PCA (5 components) of the no-pathway GCN embeddings.
pc_nopath_df = run_pca(emb_df=gcn_emb_nopathway_df, n_components=5)

In [None]:
# PCs of the full-GCN embedding next to the standardized pathway scores, aligned by cell id.
# sample_id is kept because the plot below subsamples per group_by='sample_id', and
# plot_pca_grid skips subsampling entirely when that column is absent.
plot_df = pd.concat(
    [pc_df.set_index('cell_id')[['sample_id'] + [f'PC{i}' for i in range(1, 6)]], all_pw],
    axis=1,
)

In [None]:
figdir = '../../results/figures_Fig1/'
# Create the figure folder with pathlib (imported as `pl`). `os` is not imported in
# this notebook, so os.makedirs would raise NameError.
pl.Path(figdir).mkdir(parents=True, exist_ok=True)

In [None]:
# One panel per pathway, colored by its standardized score. max_samples caps the
# number of points per sample_id group (applies only if plot_df has that column).
plot_pca_grid(
    plot_df,
    color_by_list=all_pw.columns,
    max_samples=100000,
    group_by='sample_id',
    cmap="vlag",
    ncols=5,
    savefig=pl.Path(figdir) / 'pathway_dist.png',
)

## Quantify organization

In [None]:
import numpy as np
import pandas as pd
import time
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold
import pynndescent

# ----------------------------
# kNN via pynndescent (approx)
# ----------------------------
def _drop_self_neighbors(idx, dists):
    """
    Remove each point's own entry from (n, m) kNN arrays, returning (n, m-1) arrays.

    The self-match is removed wherever it appears in the row. With duplicated points,
    a zero-distance twin can be listed before it, so it is not guaranteed to be
    in column 0. Rows that do not list the point itself lose their last column instead
    (presumably the farthest neighbor, if rows are distance-sorted), so every row keeps
    exactly m-1 neighbors in the original order.
    """
    n, m = idx.shape
    is_self = idx == np.arange(n)[:, None]
    # Mark only the first self hit per row, so exactly one column is removed per row.
    drop = is_self & (np.cumsum(is_self, axis=1) == 1)
    drop[~drop.any(axis=1), -1] = True
    keep = ~drop
    return idx[keep].reshape(n, m - 1), dists[keep].reshape(n, m - 1)


def _pynndescent_knn(X_f32, k=15):
    """
    Approximate Euclidean kNN graph of the rows of `X_f32`, built with pynndescent.

    Asks for k+1 neighbors (the index contains each point itself), then removes the
    self-match so each row keeps exactly k neighbors.

    Returns
    -------
    idx : (n, k) array of neighbor row indices
    dists : (n, k) array of Euclidean distances (not squared)
    """
    print(f"[kNN] Building pynndescent index with k={k} ...")
    t0 = time.time()
    index = pynndescent.NNDescent(
        X_f32, n_neighbors=k+1, metric="euclidean", random_state=0, n_jobs=-1
    )
    idx, dists = index.neighbor_graph  # distances (not squared)
    print(f"[kNN] Done. Took {time.time()-t0:.2f} sec")
    return _drop_self_neighbors(idx, dists)

# ----------------------------
# Moran's I (streamed)
# ----------------------------
def morans_I_stream(y_values: np.ndarray, knn_idx: np.ndarray, knn_d2: np.ndarray,
                    inverse_distance=True, eps=1e-12):
    """
    Compute Moran's I of `y_values` from kNN neighbor lists only (no sparse W).

    I = (n / S0) * sum_ij w_ij z_i z_j / sum_i z_i^2, with z = y - mean(y). The sum runs
    over the directed edges i -> knn_idx[i, j]; weights are not symmetrized.

    y_values: (n,) variable values, one per row of knn_idx
    knn_idx:  (n, k) neighbor indices
    knn_d2:   (n, k) squared distances
    inverse_distance: weight edges by 1 / (distance + eps); otherwise every weight is 1.

    Returns (I, S0), where S0 is the sum of all weights. I is NaN when it is undefined:
    empty or constant `y_values` (zero denominator, e.g. an all-zero pathway column)
    or zero total weight. The computation is done in float64.

    Raises ValueError if `y_values` and `knn_idx` disagree on the number of points.
    """
    n = knn_idx.shape[0]
    y = np.asarray(y_values, dtype=np.float64)
    if y.shape[0] != n:
        raise ValueError(f"y_values has {y.shape[0]} entries but knn_idx has {n} rows.")

    if inverse_distance:
        w = 1.0 / (np.sqrt(np.asarray(knn_d2, dtype=np.float64)) + eps)
    else:
        w = np.ones_like(knn_d2, dtype=np.float64)
    S0 = float(np.sum(w))

    # Constant y makes sum(z^2) exactly 0; the old code then raised ZeroDivisionError.
    if n == 0 or np.ptp(y) == 0 or S0 <= 0.0:
        return float("nan"), S0

    z = y - y.mean()
    z2_sum = float(np.dot(z, z))
    num_directed = float(np.sum(w * z[:, None] * z[knn_idx]))
    I = (n / S0) * (num_directed / z2_sum)
    return I, S0

# ----------------------------
# CV R^2 on pre-standardized X (reusing splits)
# ----------------------------
def cv_r2_linear_from_X(X_std: np.ndarray, y: np.ndarray, splits, alpha=1.0):
    """
    Mean and standard deviation of per-fold R^2 for a Ridge regression of y on X_std.

    X_std:  standardized features (n, d)
    y:      target (n,)
    splits: iterable of (train_idx, test_idx) index arrays; iterated once
    alpha:  Ridge regularization strength

    Each fold fits a fresh Ridge (solver "sag", random_state=0) on the training rows
    and scores the held-out rows as 1 - SS_res / (SS_tot + 1e-12). One progress line
    is printed per fold.
    """
    fold_scores = []
    fold_no = 0
    for train_rows, test_rows in splits:
        fold_no += 1
        started = time.time()
        ridge = Ridge(alpha=alpha, solver="sag", random_state=0,
                      max_iter=10_000, tol=1e-3)
        ridge.fit(X_std[train_rows], y[train_rows])
        y_test = y[test_rows]
        residual = y_test - ridge.predict(X_std[test_rows])
        ss_res = float(np.sum(residual ** 2))
        ss_tot = float(np.sum((y_test - np.mean(y_test)) ** 2))
        score = 1.0 - ss_res / (ss_tot + 1e-12)
        fold_scores.append(score)
        print(f"[CV] Fold {fold_no} R^2={score:.4f} (took {time.time()-started:.2f} sec)")
    return float(np.mean(fold_scores)), float(np.std(fold_scores))

# ----------------------------
# Shared subsample (same cells for both embeddings & all Y columns)
# ----------------------------
def _shared_subsample_indices_multi(embeddings_df: pd.DataFrame,
                                    embeddings_df_old: pd.DataFrame,
                                    Y_df: pd.DataFrame,
                                    subsample_n: int,
                                    seed: int = 0):
    # Only keep rows where ALL Y columns are non-null to keep neighbors valid for every variable
    Y_valid_idx = Y_df.dropna(how="any").index
    common_all = embeddings_df.index.intersection(embeddings_df_old.index).intersection(Y_valid_idx)
    n_common = len(common_all)
    if n_common == 0:
        raise ValueError("No overlapping cell names among embeddings_df, embeddings_df_old, and non-null rows of Y_df.")
    take_n = min(subsample_n, n_common)
    rng = np.random.default_rng(seed)
    take_idx = rng.choice(n_common, size=take_n, replace=False)
    sub_index = common_all.take(take_idx)
    return sub_index

# ----------------------------
# End-to-end with shared subsample for MULTIPLE variables
# ----------------------------
def _rows_for_cells(frame: pd.DataFrame, cells: pd.Index, label: str) -> pd.DataFrame:
    """Return frame.loc[cells]; raise if duplicated labels would add extra, misaligned rows."""
    rows = frame.loc[cells]
    if len(rows) != len(cells):
        raise ValueError(
            f"{label} has duplicated index labels among the selected cells "
            f"({len(rows)} rows for {len(cells)} cells); cell names must be unique."
        )
    return rows


def compare_embeddings_subsampled_samecells_multi(embeddings_df: pd.DataFrame,
                                                  embeddings_df_old: pd.DataFrame,
                                                  Y_df: pd.DataFrame,
                                                  k: int = 15,
                                                  subsample_n: int = 100_000,
                                                  inverse_distance: bool = True,
                                                  seed: int = 0,
                                                  n_splits: int = 5,
                                                  alpha: float = 1.0):
    """
    Compare two embeddings on one shared cell subsample, for every column of Y_df.

    Subsamples the SAME (up to `subsample_n`) cells from the intersection of
    {embeddings_df, embeddings_df_old, rows of Y_df non-null in every column}. Then,
    for each embedding ("new" = embeddings_df, "old" = embeddings_df_old):
      - standardizes its features once,
      - builds a pynndescent kNN graph once on the standardized features,
      - for each Y column computes Moran's I over that graph, and the mean/sd of
        n-fold Ridge CV R^2 predicting the column from the embedding. The same folds
        are reused for every variable and both embeddings.

    Returns
    -------
    results : dict
        results[tag][variable] -> metrics dict, with tag in {"new", "old"}.
    tidy_df : pd.DataFrame
        One row per (embedding, variable), indexed by that pair and sorted.
    sub_index : pd.Index
        Names of the shared cells; row i of every X/y used corresponds to sub_index[i].

    Raises
    ------
    ValueError
        If no cells are shared, Y_df has no columns, or duplicated index labels make
        the selected rows of an input disagree with the shared subsample. Such a
        mismatch would silently misalign X, y and the kNN indices.
    """
    # ---------- select shared subsample ----------
    print(f"[Setup] Selecting shared subsample (up to {subsample_n:,} cells, seed={seed}) ...")
    t0 = time.time()
    sub_index = _shared_subsample_indices_multi(embeddings_df, embeddings_df_old, Y_df, subsample_n, seed)
    print(f"[Setup] Chosen {len(sub_index):,} shared cells in {time.time()-t0:.2f} sec.")

    # Slice Y once (row i of Y_sub is cell sub_index[i])
    Y_sub = _rows_for_cells(Y_df, sub_index, "Y_df").astype(np.float32)
    y_cols = list(Y_sub.columns)
    if not y_cols:
        raise ValueError("Y_df has no columns to evaluate.")
    print(f"[Setup] {len(y_cols)} variables: {y_cols}")

    # Prepare CV splits ONCE (sized by the subsample) and reuse across variables and embeddings
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=0)
    splits = list(kf.split(np.arange(len(sub_index))))

    weights_label = "inverse_distance" if inverse_distance else "binary"
    results = {"new": {}, "old": {}}
    tidy_rows = []

    # ---------- per-embedding pipeline (neighbors once, reuse for all y cols) ----------
    for tag, df in [("new", embeddings_df), ("old", embeddings_df_old)]:
        print(f"\n=== Processing embedding: {tag} ===")
        t_total = time.time()

        # Extract X for the shared subsample, in the same row order as Y_sub
        print("[Data] Extracting X for shared subsample ...")
        t1 = time.time()
        X = _rows_for_cells(df, sub_index, f"{tag} embedding").to_numpy(dtype=np.float32)
        print(f"[Data] Done in {time.time()-t1:.2f} sec. Using {X.shape[0]:,} cells and {X.shape[1]} dims.")

        # Standardize features once.
        # NOTE(review): the scaler is fit on the whole subsample, so CV test folds also
        # inform the scaling. For StandardScaler + Ridge this leakage is usually minor.
        print("[Preproc] Standardizing features ...")
        t1 = time.time()
        X_std = StandardScaler(with_mean=True, with_std=True).fit_transform(X).astype(np.float32)
        print(f"[Preproc] Done. Took {time.time()-t1:.2f} sec")

        # Neighbors once per embedding
        idx, dists = _pynndescent_knn(X_std, k=k)
        d2 = (dists.astype(np.float32)) ** 2  # squared distances for Moran's I

        for c_idx, col in enumerate(y_cols, 1):
            print(f"\n[{tag} | Var {c_idx}/{len(y_cols)}] {col}")
            yv = Y_sub[col].to_numpy(dtype=np.float32)

            # Moran's I
            t_m = time.time()
            I, S0 = morans_I_stream(yv, idx, d2, inverse_distance=inverse_distance)
            print(f"[Moran] I={I:.6f} (took {time.time()-t_m:.2f} sec)")

            # CV R^2 (reuse splits)
            print("[CV] Cross-validated Ridge R^2 ...")
            t_cv = time.time()
            r2_mean, r2_sd = cv_r2_linear_from_X(X_std, yv, splits, alpha=alpha)
            print(f"[CV] mean={r2_mean:.4f}, sd={r2_sd:.4f} (took {time.time()-t_cv:.2f} sec)")

            # Store one metrics record in both the nested dict and the tidy rows
            record = {
                "morans_I": float(I),
                "S0": float(S0),
                "cv_r2_mean": float(r2_mean),
                "cv_r2_sd": float(r2_sd),
                "n_cells_subsample": int(X_std.shape[0]),
                "k": int(k),
                "weights": weights_label,
                "seed": int(seed),
                "shared_cells": True,
                "n_splits": int(n_splits),
                "alpha": float(alpha),
            }
            results[tag][col] = record
            tidy_rows.append({"embedding": tag, "variable": col, **record})

        print(f"=== Done with embedding: {tag}. Total time {time.time()-t_total:.2f} sec ===")

    tidy_df = pd.DataFrame(tidy_rows).set_index(["embedding", "variable"]).sort_index()
    return results, tidy_df, sub_index


In [None]:
# Compare the full ("new") and no-pathway ("old") GCN embeddings on the same cells,
# using embedding columns '0'..'9' and every standardized pathway in all_pw.
results, tidy_df, shared_cells = compare_embeddings_subsampled_samecells_multi(
    embeddings_df=gcn_emb_df.set_index('cell_id').loc[:, [str(d) for d in range(10)]],
    embeddings_df_old=gcn_emb_nopathway_df.set_index('cell_id').loc[:, [str(d) for d in range(10)]],
    Y_df=all_pw,  # variables to evaluate (index = cell names)
    k=15,
    subsample_n=300_000,
    inverse_distance=True,
    seed=42,
    n_splits=5,
    alpha=1.0,
)

print(tidy_df)  # one row per (embedding, variable)
# shared_cells: index of the cells used for every metric and both embeddings


In [None]:
# Persist the per-(embedding, variable) metrics table next to the figures.
tidy_df.to_csv(pl.Path('../../results/figures_Fig1') / 'results_comparison.csv')

In [None]:
# Display the tidy comparison table (one row per (embedding, variable)).
tidy_df

In [None]:
# Pivot mean CV R^2 to wide form: one row per variable, one column per embedding.
r2_wide = tidy_df["cv_r2_mean"].unstack("embedding")
r2_wide

In [None]:
# Persist the wide R^2 comparison next to the figures.
r2_wide.to_csv(pl.Path('../../results/figures_Fig1') / 'r2_pathway_comparison.csv')