In [None]:
import scanpy as sc
import numpy as np
import pandas as pd
import pathlib as pl

import torch
from torch.utils.data import TensorDataset, DataLoader

from tqdm import tqdm

In [None]:
from spatialfusion.models.baseline_multi_ae import PairedAE  # your custom model class
from spatialfusion.utils.baseline_ae_data_loader import load_and_preprocess_sample_baseline

In [None]:
# Number of cells per forward pass when encoding a sample.
BATCH_SIZE = 128  # Tune for your GPU/CPU capacity
# Label assigned to cells for which no cell-type annotation is available.
DEFAULT_LABEL = "unknown"

# Define function to extract baseline embeddings

In [None]:
def extract_embeddings_baseline(model, sample_list,
                                base_path, raw_path, used_genes,
                                device="cpu", image_size=224):
    """
    Encode every sample with the baseline paired autoencoder and collect the
    per-cell latent vectors together with cell-type and sample labels.

    For each sample the image and expression tensors are loaded, duplicated
    cell ids are dropped (first occurrence kept), both encoders are run in
    batches of ``BATCH_SIZE``, and the joint embedding is taken as the mean of
    the two latents. Labels are read from ``<base_path>/<sample>/celltypes.csv``
    if present, otherwise from ``adata.h5ad`` (``obs["celltypes"]``). When a
    label source exists, only cells present in it are kept; when none exists,
    all cells are kept and labelled ``DEFAULT_LABEL``.

    Samples that raise during processing are skipped with a printed warning
    (deliberate best-effort behaviour).

    Args:
        model: Model exposing ``encoder1`` (image input) and ``encoder2``
            (expression input); must already live on ``device``.
        sample_list: Iterable of sample names.
        base_path: Directory containing one sub-directory per sample.
        raw_path: Forwarded to ``load_and_preprocess_sample_baseline``.
        used_genes: Gene list forwarded as ``SOFT_UNION_GENE_LIST``.
        device: Torch device the input batches are moved to.
        image_size: Forwarded to the sample loader.

    Returns:
        Tuple ``(z1, z2, z_joint, celltypes, samples)``: three DataFrames
        indexed by cell id and two 1-D arrays aligned with their rows.

    Raises:
        ValueError: If no sample produced any embeddings.
    """
    all_z1, all_z2, all_zjoint = [], [], []
    all_celltypes, all_samples = [], []

    model.eval()

    with torch.no_grad():
        for sample in tqdm(sample_list, desc="Samples"):
            try:
                # Load data
                img_tensor, gex_tensor, cell_ids = load_and_preprocess_sample_baseline(
                    sample_name=sample,
                    base_path=base_path,
                    raw_path=raw_path,
                    SOFT_UNION_GENE_LIST=used_genes,
                    max_cells=100000,
                    image_size=image_size
                )

                # Deduplicate cell_ids: keep the first occurrence of each id, in original order
                cell_ids = pd.Index(cell_ids)
                unique_idx = np.flatnonzero(~cell_ids.duplicated(keep="first"))
                cell_ids_unique = cell_ids[unique_idx]
                img_tensor = img_tensor[unique_idx]
                gex_tensor = gex_tensor[unique_idx]

                # Encode. The loader must NOT shuffle: row order is what ties
                # the embeddings back to cell_ids_unique.
                dataset = TensorDataset(img_tensor.float(), gex_tensor.float())
                loader = DataLoader(dataset, batch_size=BATCH_SIZE)

                z1_list, z2_list = [], []
                for x1, x2 in tqdm(loader, desc=f"Encoding {sample}", leave=False):
                    x1, x2 = x1.to(device), x2.to(device)
                    # no_grad is active, so the outputs carry no graph to detach
                    z1_list.append(model.encoder1(x1).cpu())
                    z2_list.append(model.encoder2(x2).cpu())

                z1 = torch.cat(z1_list).numpy()
                z2 = torch.cat(z2_list).numpy()
                z_joint = (z1 + z2) / 2

                # Sanity check
                if z1.shape[0] != len(cell_ids_unique):
                    print(f"[DEBUG] z1 shape: {z1.shape}, expected: {len(cell_ids_unique)}")
                    raise ValueError("Mismatch between z1 and cell_ids length")

                # Load labels: prefer the CSV, fall back to the AnnData file
                ct_path = pl.Path(base_path) / sample / "celltypes.csv"
                adata_path = pl.Path(base_path) / sample / "adata.h5ad"
                label_source = None

                if ct_path.exists():
                    label_source = pd.read_csv(ct_path, index_col=0).iloc[:, 0]
                elif adata_path.exists():
                    adata = sc.read_h5ad(adata_path)
                    if "celltypes" in adata.obs.columns:
                        label_source = adata.obs["celltypes"]

                # Align labels
                if label_source is not None:
                    label_source = label_source[~label_source.index.duplicated(keep="first")]
                    valid_ids = cell_ids_unique.intersection(label_source.index)

                    if len(valid_ids) == 0:
                        print(f"[Warning] No label overlap for {sample}. Skipping.")
                        continue

                    idx = cell_ids_unique.get_indexer(valid_ids)
                    if np.any(idx < 0) or np.max(idx) >= len(z1):
                        print(f"[DEBUG] Sample {sample}")
                        print(f"[DEBUG] idx min/max: {idx.min()} / {idx.max()}, z1 shape: {z1.shape}")
                        raise IndexError("Invalid index range after label alignment")

                    z1 = z1[idx]
                    z2 = z2[idx]
                    z_joint = z_joint[idx]
                    # Cast to object first: fillna on a categorical column raises
                    # when DEFAULT_LABEL is not one of its categories.
                    labels = (
                        label_source.reindex(valid_ids)
                        .astype(object)
                        .fillna(DEFAULT_LABEL)
                        .to_numpy()
                    )
                    cell_ids_final = valid_ids
                else:
                    labels = np.full(len(cell_ids_unique), DEFAULT_LABEL, dtype=object)
                    cell_ids_final = cell_ids_unique

                # Store
                all_z1.append(pd.DataFrame(z1, index=cell_ids_final))
                all_z2.append(pd.DataFrame(z2, index=cell_ids_final))
                all_zjoint.append(pd.DataFrame(z_joint, index=cell_ids_final))
                all_celltypes.append(labels)
                all_samples.append([sample] * len(cell_ids_final))

            except Exception as e:
                # Best-effort: one broken sample must not abort the whole run
                print(f"[Warning] Skipping {sample} due to error: {e}")
                continue

    # Fail with an actionable message instead of pandas' "No objects to concatenate"
    if not all_z1:
        raise ValueError(
            "No embeddings were extracted: sample_list is empty or every sample was skipped."
        )

    # Final concatenation
    return (
        pd.concat(all_z1),
        pd.concat(all_z2),
        pd.concat(all_zjoint),
        np.concatenate(all_celltypes),
        np.concatenate(all_samples)
    )


# Get embeddings for the baseline AE

In [None]:
from omegaconf import OmegaConf
# Config file stored in the baseline_ae run directory (run id 28f7875f).
cfg = OmegaConf.load("../../results/logs/baseline_ae/run_20251022-155824_28f7875f/config_28f7875f.yaml")

# Gene list from the same run directory: read without a header and flattened to 1-D.
used_genes = pd.read_csv("../../results/logs/baseline_ae/run_20251022-155824_28f7875f/used_genes.txt", header=None).values.ravel()

# One entry per gene in the list; passed to PairedAE as d2_dim.
d2_dim = len(used_genes)

# Rebuild the model with the hyperparameters recorded in the run config.
model = PairedAE(
    d2_dim=d2_dim,  # number of genes in used_genes.txt
    latent_dim=cfg.training.latent_dim,
    resnet_backbone=cfg.training.resnet_backbone,
    freeze_resnet=cfg.training.freeze_resnet
)

In [None]:
# --- Load state dict ---
# Weights are first mapped to CPU so loading works regardless of the device
# the checkpoint was saved from.
checkpoint_path = "../../results/checkpoint_dir_ae/baseline_ae/paired_model_28f7875f.pt"
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

# --- Move to device ---
# Prefer the GPU index used for the original run, but fall back to the first
# GPU (or CPU) so the notebook still runs on machines with fewer devices.
PREFERRED_GPU = 5
if torch.cuda.is_available():
    device = f"cuda:{PREFERRED_GPU}" if torch.cuda.device_count() > PREFERRED_GPU else "cuda:0"
else:
    device = "cpu"
model.to(device)
model.eval()

In [None]:
# Run embedding extraction on the test samples listed in the run config.
# Returns the two per-encoder embeddings, their average (joint), and the
# per-cell cell-type / sample labels aligned with the embedding rows.
z1_baseline, z2_baseline, z_joint_baseline, celltypes_baseline, samples_baseline = extract_embeddings_baseline(
    model,
    sample_list=cfg.dataset.test_samples,
    base_path=cfg.dataset.datapath,
    raw_path=cfg.dataset.rawpath,
    used_genes=used_genes,
    device=device
)

In [None]:
from spatialfusion.utils.embed_ae_utils import save_embeddings_separately

In [None]:
# Persist the baseline embeddings and labels under the baseline output directory.
# NOTE(review): presumably writes the same file layout that is read back for the
# multimodal AE further down (z*_test.parquet + metadata_test.h5) — confirm.
save_embeddings_separately(
    z1_baseline, z2_baseline, z_joint_baseline, celltypes_baseline, samples_baseline, mode='test',
    out_dir="../../results/embeddings_ae/baseline-AE-test-output/"
)


# Load embeddings from comparable multimodal AE 

In [None]:
# Load the joint test-set embeddings previously saved for model 9e693874.
z_joint = pd.read_parquet("../../results/embeddings_ae/output-model-9e693874-test/z_joint_test.parquet")

In [None]:
# First per-modality embedding of model 9e693874 (plotted as the H&E panel below).
z1 = pd.read_parquet("../../results/embeddings_ae/output-model-9e693874-test/z1_test.parquet")

In [None]:
# Second per-modality embedding of model 9e693874 (plotted as the RNA panel below).
# The path is relative like the z1 / z_joint loads; a leading "/" would resolve
# from the filesystem root and miss the results directory.
z2 = pd.read_parquet("../../results/embeddings_ae/output-model-9e693874-test/z2_test.parquet")

In [None]:
import h5py
# Read the per-cell cell-type and sample labels stored next to the embeddings.
# Both datasets are loaded fully into memory and converted to str arrays.
with h5py.File("../../results/embeddings_ae/output-model-9e693874-test/metadata_test.h5", "r") as f:
    celltypes = f["celltypes"][:].astype(str)
    samples = f["samples"][:].astype(str)

# Compare performance of AEs

In [None]:
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import balanced_accuracy_score, f1_score
from sklearn.pipeline import make_pipeline
from tqdm import tqdm
import numpy as np
import pandas as pd

def evaluate_classification_lr(
    z_joint_df: pd.DataFrame,
    celltypes: np.ndarray,
    n_splits: int = 4,
    subsample_size: int | None = None,
    random_state: int = 42,
):
    """
    Evaluate classification of celltypes from embeddings using Stratified K-fold
    Logistic Regression. Optionally subsample first (stratified) and run
    everything on that subset.

    Args:
        z_joint_df: DataFrame of embeddings (cells x features).
        celltypes: Array-like of ground-truth cell type labels.
        n_splits: Number of CV folds (may be reduced if some classes are tiny).
        subsample_size: If set, run on a stratified subsample of this many rows.
        random_state: Seed for reproducibility.

    Returns:
        dict with average 'balanced_accuracy' and 'f1_macro'.
    """
    # Filter out unknown/missing labels
    valid = (celltypes != "unknown") & pd.notnull(celltypes)
    X_full = z_joint_df.iloc[valid].to_numpy(copy=False)
    y_full = np.asarray(celltypes)[valid]

    # Need at least 2 classes
    if len(np.unique(y_full)) < 2:
        return {"balanced_accuracy": np.nan, "f1_macro": np.nan}

    # ---- Optional stratified subsample ----
    if subsample_size is not None and subsample_size < len(X_full):
        n_keep = int(subsample_size)
        sss = StratifiedShuffleSplit(
            n_splits=1, train_size=n_keep, random_state=random_state
        )
        (keep_idx, _), = sss.split(X_full, y_full)
        X = X_full[keep_idx]
        y = y_full[keep_idx]
    else:
        X, y = X_full, y_full

    # Encode labels after subsampling
    le = LabelEncoder()
    y_encoded = le.fit_transform(y)

    # Ensure we can do stratified K-fold: each class must have >= n_splits samples
    _, counts = np.unique(y_encoded, return_counts=True)
    min_class_count = int(counts.min())
    if min_class_count < 2:
        # Too few samples in at least one class to run CV
        return {"balanced_accuracy": np.nan, "f1_macro": np.nan}

    n_splits_eff = min(n_splits, min_class_count)
    if n_splits_eff < n_splits:
        print(f"[info] Reducing n_splits from {n_splits} to {n_splits_eff} "
              f"because the smallest class has only {min_class_count} samples.")

    skf = StratifiedKFold(n_splits=n_splits_eff, shuffle=True, random_state=random_state)

    bac_scores, f1_scores = [], []

    for train_idx, test_idx in tqdm(skf.split(X, y_encoded), total=n_splits_eff):
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y_encoded[train_idx], y_encoded[test_idx]

        clf = make_pipeline(
            StandardScaler(),
            LogisticRegression(
                max_iter=1000,
                class_weight="balanced",
                random_state=random_state,
                solver="lbfgs",
                multi_class="auto",
            ),
        )
        clf.fit(X_train, y_train)
        y_pred = clf.predict(X_test)

        bac_scores.append(balanced_accuracy_score(y_test, y_pred))
        f1_scores.append(f1_score(y_test, y_pred, average="macro"))

    return {
        "balanced_accuracy": float(np.mean(bac_scores)),
        "f1_macro": float(np.mean(f1_scores)),
    }


In [None]:
# Cell-type classification from the multimodal AE joint embeddings
# (4-fold CV on a stratified subsample of at most 400k labelled cells).
lr_evaluation_AE = evaluate_classification_lr(z_joint_df=z_joint,
                                celltypes=celltypes,
                                n_splits=4, subsample_size=400_000,
                                random_state=42)

In [None]:
lr_evaluation_AE

In [None]:
# Same evaluation, same settings, on the baseline AE joint embeddings.
lr_evaluation_baseline = evaluate_classification_lr(z_joint_df=z_joint_baseline,
                                celltypes=celltypes_baseline,
                                n_splits=4, subsample_size=400_000,
                                random_state=42)

In [None]:
lr_evaluation_baseline

In [None]:
import numpy as np
import pandas as pd
from sklearn.metrics import silhouette_samples
from sklearn.preprocessing import LabelEncoder

def batch_asw_fast(
    embedding: np.ndarray,
    batch_labels: np.ndarray,
    celltype_labels: np.ndarray,
    metric: str = 'euclidean',
    max_cells: int = 20000,
    scale: bool = True,
    random_state: int = 42,
):
    """
    Compute a batch average silhouette width (batch ASW) with stratified subsampling.
    Designed for very large datasets (~millions of cells).

    Silhouette values are computed once over all retained cells, using the
    batch as the cluster label. Their absolute values are averaged within each
    cell type, and those per-cell-type means are averaged again.

    Parameters
    ----------
    embedding : (n_cells, n_dims) array-like
        Low-dimensional embedding of cells (e.g., PCA, UMAP, etc.).
    batch_labels : (n_cells,) array-like
        Batch assignment for each cell (matched to rows by position).
    celltype_labels : (n_cells,) array-like
        Cell-type or biological group assignment for each cell.
    metric : str
        Distance metric for silhouette_samples (default 'euclidean').
    max_cells : int
        Approximate cap on the number of cells used. Every cell-type x batch
        stratum keeps at least one cell, so the cap can be slightly exceeded.
    scale : bool
        Whether to return 1 - mean|silhouette| (so higher = better mixing).
    random_state : int
        Seed used for the per-stratum sampling. The global NumPy RNG is not touched.

    Returns
    -------
    float
        The batch ASW score (higher = better batch mixing when ``scale`` is True).

    Raises
    ------
    ValueError
        Propagated from ``silhouette_samples`` when the retained cells do not
        contain a valid number of distinct batches (fewer than 2).
    """
    # Work on plain arrays so every index below is positional, even if the
    # caller passes pandas objects with a non-default index.
    embedding = np.asarray(embedding)
    batch_labels = np.asarray(batch_labels)
    celltype_labels = np.asarray(celltype_labels)

    n_cells = embedding.shape[0]
    if n_cells <= max_cells:
        idx_keep = np.arange(n_cells)
    else:
        # --- Stratified subsampling: preserve cell-type × batch proportions ---
        strata = (
            pd.Series(celltype_labels).astype(str)
            + "_"
            + pd.Series(batch_labels).astype(str)
        )

        # target number of samples per stratum, proportional to its frequency
        strata_counts = strata.value_counts()
        frac = min(1.0, max_cells / n_cells)
        n_per_strata = np.maximum(1, (strata_counts * frac).astype(int))

        # Sample each stratum independently; the Series has a RangeIndex, so
        # the sampled index values are row positions.
        idx_keep = np.concatenate([
            members.sample(
                n=min(len(members), int(n_per_strata[name])),
                random_state=random_state,
            ).index.to_numpy()
            for name, members in strata.groupby(strata)
        ])

    # Subset data
    emb_sub = embedding[idx_keep]
    celltype_sub = celltype_labels[idx_keep]

    # Encode batch labels as integers for sklearn
    batch_sub = LabelEncoder().fit_transform(batch_labels[idx_keep])

    # Compute silhouette for batches
    sil_vals = silhouette_samples(emb_sub, batch_sub, metric=metric)

    # Mean |silhouette| per cell type, then averaged across cell types so
    # that abundant cell types do not dominate the score
    abs_sil = pd.Series(np.abs(sil_vals))
    per_ct = abs_sil.groupby(celltype_sub).mean().to_numpy()
    avg_abs_sil = float(np.mean(per_ct))

    return 1.0 - avg_abs_sil if scale else avg_abs_sil


In [None]:
# Batch ASW of the multimodal AE joint embedding. The score is also bound to a
# model-specific name because `bASW` is reassigned for the baseline model later.
bASW_ae = bASW = batch_asw_fast(
    embedding=z_joint.values,
    batch_labels=samples,
    celltype_labels=celltypes,
    metric='euclidean',
    max_cells=100000,
    scale=True
)

In [None]:
bASW

In [None]:
# Batch ASW of the baseline AE joint embedding. Bound to a model-specific name
# as well, so it stays available if `bASW` is reassigned again.
bASW_baseline = bASW = batch_asw_fast(
    embedding=z_joint_baseline.values,
    batch_labels=samples_baseline,
    celltype_labels=celltypes_baseline,
    metric='euclidean',
    max_cells=100000,
    scale=True
)

In [None]:
bASW

In [None]:
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from scipy.stats import chi2

def kbet(
    embedding: np.ndarray,
    batch_labels: np.ndarray,
    k: int = 50,
    max_cells: int = 20000,
    alpha: float = 0.05,
    random_state: int = 42,
) -> float:
    """
    Compute a kBET-style (k-nearest-neighbor Batch-Effect Test) acceptance rate.

    For every (subsampled) cell, the batch composition of its k nearest
    neighbours is compared against the overall batch proportions with a
    Pearson chi-squared statistic; the cell is "rejected" when the statistic
    exceeds the chi-squared quantile at level ``1 - alpha`` with
    ``n_batches - 1`` degrees of freedom.

    Parameters
    ----------
    embedding : (n_cells, n_dims) array-like
        Low-dimensional embedding of cells (e.g., PCA or UMAP).
    batch_labels : (n_cells,) array-like
        Batch assignment for each cell (matched to rows by position).
    k : int
        Number of neighbors for the local test (default 50). Reduced to
        ``n - 1`` when fewer cells than ``k + 1`` are available.
    max_cells : int
        Max number of cells to subsample (uniformly, without replacement).
    alpha : float
        Significance level for chi-squared test (default 0.05).
    random_state : int
        Random seed for the subsample.

    Returns
    -------
    float
        kBET acceptance rate = 1 - rejection_rate (higher = better mixing).
        NaN for empty input; 1.0 when only one batch is present (nothing to reject).
    """
    embedding = np.asarray(embedding)
    batch_labels = np.asarray(batch_labels)

    n_cells = embedding.shape[0]
    if n_cells == 0:
        return float("nan")

    rng = np.random.default_rng(random_state)
    if n_cells > max_cells:
        idx = rng.choice(n_cells, size=max_cells, replace=False)
    else:
        idx = np.arange(n_cells)

    emb_sub = embedding[idx]
    batch_sub = batch_labels[idx]
    n_sub = emb_sub.shape[0]

    unique_batches, batch_codes, batch_counts = np.unique(
        batch_sub, return_inverse=True, return_counts=True
    )
    batch_codes = batch_codes.ravel()
    n_batches = len(unique_batches)
    if n_batches < 2:
        # With a single batch every neighbourhood matches the expectation exactly.
        return 1.0

    # A cell can have at most n_sub - 1 neighbours other than itself.
    k_eff = min(k, n_sub - 1)
    if k_eff < k:
        print(f"[info] Reducing k from {k} to {k_eff} because only {n_sub} cells are available.")

    # Build neighbor graph. Calling kneighbors() without a query matrix already
    # excludes each point from its own neighbour list, so exactly k_eff
    # neighbours are requested and none of the returned columns is dropped.
    nbrs = NearestNeighbors(n_neighbors=k_eff, metric='euclidean').fit(emb_sub)
    knn_idx = nbrs.kneighbors(return_distance=False)

    # Observed batch counts per neighbourhood, shape (n_sub, n_batches):
    # offset each row's batch codes into its own bin range and count in one pass.
    neighbor_codes = batch_codes[knn_idx]
    row_offsets = np.arange(n_sub)[:, None] * n_batches
    obs_counts = np.bincount(
        (neighbor_codes + row_offsets).ravel(), minlength=n_sub * n_batches
    ).reshape(n_sub, n_batches)

    # Expected batch counts in each neighborhood (all > 0 by construction)
    expected = k_eff * batch_counts / batch_counts.sum()
    chi2_threshold = chi2.ppf(1 - alpha, df=n_batches - 1)

    chi2_stats = ((obs_counts - expected) ** 2 / expected).sum(axis=1)
    rejection_rate = float(np.mean(chi2_stats > chi2_threshold))
    return 1.0 - rejection_rate


In [None]:
# kBET acceptance rate of the multimodal AE joint embedding, with the sample as batch.
kbet_ae = kbet(
    embedding=z_joint.values,
    batch_labels=samples,
    k=50,
    max_cells=100000,
    alpha=0.05,
    random_state=42,
) 

In [None]:
kbet_ae

In [None]:
# Same kBET settings applied to the baseline AE joint embedding.
kbet_baseline = kbet(
    embedding=z_joint_baseline.values,
    batch_labels=samples_baseline,
    k=50,
    max_cells=100000,
    alpha=0.05,
    random_state=42,
) 

In [None]:
kbet_baseline

# Plot latent spaces

In [None]:
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

def plot_embeddings(
    z1, z2, z_joint, labels, samples, label_title,
    palette="tab10", seed=42, max_cells=500_000, savefig=None,
    show_titles=False,
):
    """
    Plot 2-D PCA projections of three embeddings side by side, coloured by ``labels``.

    Cells are subsampled per sample (an equal share of ``max_cells`` each),
    each embedding is projected with its own 2-component PCA, and points are
    drawn in random order so that no category is systematically on top. A
    single legend shared by all panels is placed to the right of the figure.

    Parameters
    ----------
    z1, z2, z_joint : pd.DataFrame
        Embeddings with identical row order (cells x features). Panels are
        titled "H&E", "RNA" and "Joint" respectively when titles are shown.
    labels : array-like
        Per-cell values used for colouring.
    samples : array-like
        Per-cell sample ids, used only to balance the subsample.
    label_title : str
        Legend title.
    palette : str | sequence | dict
        Passed to ``seaborn.scatterplot``.
    seed : int
        Seed for subsampling, draw order and the PCA solver.
    max_cells : int
        Upper bound on the total number of plotted cells.
    savefig : str | None
        If given, the figure is written to this path (dpi=200).
    show_titles : bool
        Draw the panel titles. Off by default, so panels are untitled.
    """
    rng = np.random.default_rng(seed)

    # Row positions grouped by sample (plain arrays avoid index re-alignment)
    df = pd.DataFrame({"sample": np.asarray(samples)}, index=z1.index)
    df["idx"] = np.arange(len(df))
    grouped = df.groupby("sample")

    # Determine how many cells per sample to keep
    max_per_sample = max_cells // max(1, len(grouped))

    # Subsample per sample group
    selected_indices = []
    for _, group in grouped:
        n = min(len(group), max_per_sample)
        selected_indices.extend(rng.choice(group["idx"].values, size=n, replace=False))

    # Subset everything
    z1 = z1.iloc[selected_indices]
    z2 = z2.iloc[selected_indices]
    z_joint = z_joint.iloc[selected_indices]
    labels = np.array(labels)[selected_indices]

    # Apply PCA (fit independently per embedding). Seeding matters because the
    # solver chosen for large inputs can be randomized.
    def _pca_2d(frame):
        return PCA(n_components=2, random_state=seed).fit_transform(frame.values)

    z1_pca = _pca_2d(z1)
    z2_pca = _pca_2d(z2)
    z_joint_pca = _pca_2d(z_joint)

    # Shuffle for plot order
    perm = rng.permutation(len(labels))
    z1_pca = z1_pca[perm]
    z2_pca = z2_pca[perm]
    z_joint_pca = z_joint_pca[perm]
    labels = labels[perm]

    # Plot
    fig, axes = plt.subplots(1, 3, figsize=(16, 5))
    point_size = 5  # keep a single source of truth for both scatter and legend markers

    def _strip_spines(ax):
        # Remove ALL spines (no plot borders)
        for spine in ax.spines.values():
            spine.set_visible(False)
        # No ticks and equal aspect
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect("equal")

    def plot_scatter(ax, emb, title, show_legend=False):
        sns.scatterplot(
            x=emb[:, 0],
            y=emb[:, 1],
            hue=labels,
            palette=palette,
            s=point_size,
            ax=ax,
            linewidth=0,
            alpha=0.8,
            legend="full" if show_legend else False,
        )
        ax.set_title(title if show_titles else '', fontsize=25)
        _strip_spines(ax)

    plot_scatter(axes[0], z1_pca, "H&E (PCA)")
    plot_scatter(axes[1], z2_pca, "RNA (PCA)")
    # Only the last panel builds a legend; it is lifted to figure level below
    plot_scatter(axes[2], z_joint_pca, "Joint (PCA)", show_legend=True)

    # Extract and remove subplot legend
    handles, labels_ = axes[2].get_legend_handles_labels()
    if axes[2].legend_ is not None:
        axes[2].legend_.remove()

    # Start legend markers from the scatter point size
    # (handles from seaborn are PathCollections for the color items)
    for h in handles:
        if hasattr(h, "set_sizes"):
            h.set_sizes([point_size])

    # Add shared figure legend
    fig.legend(
        handles,
        labels_,
        loc="center left",
        bbox_to_anchor=(1.01, 0.5),
        title=label_title,
        scatterpoints=1,     # one marker per legend entry
        markerscale=10,      # enlarge legend markers; the scatter points are too small to read
        frameon=False,       # no border around the legend
        fontsize=16,         # legend text size
        title_fontsize=18,
    )

    # Lay out first so the saved file matches what is displayed
    plt.tight_layout()

    if savefig is not None:
        fig.savefig(savefig, dpi=200, bbox_inches='tight')

    plt.show()


In [None]:
# Base colors for cell types: seaborn's current default color cycle.
palette = sns.color_palette()

In [None]:
# Base colors for samples/batches: seaborn's 'muted' palette, to differ from the cell-type colors.
palette_batch = sns.color_palette('muted')

In [None]:
# Fixed cell type -> color mapping shared by every plot. Cell types of the
# multimodal AE come first (sorted) so their colors do not depend on the
# baseline; labels that occur only in the baseline output are appended so the
# baseline plots never hit a missing palette key.
_ct_known = set(np.unique(celltypes))
_ct_names = list(np.unique(celltypes)) + [
    ct for ct in np.unique(celltypes_baseline) if ct not in _ct_known
]
# The base palette is finite; switch to an evenly spaced palette when there
# are more categories than colors instead of failing with IndexError.
_ct_colors = palette if len(_ct_names) <= len(palette) else sns.color_palette("husl", len(_ct_names))
ct_palette = {ct: _ct_colors[i] for i, ct in enumerate(_ct_names)}

In [None]:
# Fixed sample -> color mapping shared by every plot. Samples of the
# multimodal AE come first (sorted); samples that occur only in the baseline
# output are appended so the baseline plots never hit a missing palette key.
_btc_known = set(np.unique(samples))
_btc_names = list(np.unique(samples)) + [
    btc for btc in np.unique(samples_baseline) if btc not in _btc_known
]
# The base palette is finite; switch to an evenly spaced palette when there
# are more samples than colors instead of failing with IndexError.
_btc_colors = (
    palette_batch
    if len(_btc_names) <= len(palette_batch)
    else sns.color_palette("husl", len(_btc_names))
)
batch_palette = {btc: _btc_colors[i] for i, btc in enumerate(_btc_names)}

In [None]:
# Multimodal AE latent spaces, coloured by cell type and by sample.
# Figures go to the same results/figures_Fig1 directory as the baseline plots;
# the path no longer depends on the name of the repository checkout directory.
plot_embeddings(z1, z2, z_joint, celltypes, samples, label_title="Cell Type", max_cells=1000000, palette=ct_palette,
                savefig='../../results/figures_Fig1/ae_celltype_scatter.png')
plot_embeddings(z1, z2, z_joint, samples, samples, label_title="Sample", max_cells=1000000, palette=batch_palette,
                savefig='../../results/figures_Fig1/ae_batch_scatter.png')

In [None]:
# Baseline AE latent spaces, coloured by cell type and by sample, using the
# same palettes as the multimodal AE plots so colors are comparable.
plot_embeddings(z1_baseline, z2_baseline, z_joint_baseline, celltypes_baseline, samples_baseline,
                label_title="Cell Type", max_cells=1000000, palette=ct_palette,
                savefig='../../results/figures_Fig1/baseline_ae_celltype_scatter.png')
plot_embeddings(z1_baseline, z2_baseline, z_joint_baseline, samples_baseline, samples_baseline,
                label_title="Sample", max_cells=1000000, palette=batch_palette,
               savefig='../../results/figures_Fig1/baseline_ae_batch_scatter.png')