In [None]:
import matplotlib
matplotlib.rcParams['svg.fonttype'] = 'none'

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import GroupKFold
from tqdm import tqdm
import numpy as np
import random

# ------------------------------------------------------------
# 0. Utility: patient-level normalization
# ------------------------------------------------------------
def patient_scale(bags):
    """
    Standardize every patient's bag feature-wise (per-patient z-score).

    bags: dict patient_id -> tensor [N_i, d]
    Returns a new dict patient_id -> tensor [N_i, d] in which every feature
    column was centred on the bag mean and divided by the bag standard
    deviation (+ 1e-6 to avoid division by zero). Inputs are not modified.

    A bag with fewer than two rows has no defined sample standard deviation
    (torch.std returns NaN in that case), so its spread is taken as zero and
    the bag is mapped to zeros instead of NaN.
    """
    scaled = {}
    for pid, X in bags.items():
        mean = X.mean(dim=0, keepdim=True)
        if X.shape[0] > 1:
            std = X.std(dim=0, keepdim=True)
        else:
            # single-row (or empty) bag: avoid the NaN produced by torch.std
            std = torch.zeros_like(mean)
        scaled[pid] = (X - mean) / (std + 1e-6)
    return scaled


# ------------------------------------------------------------
# 1. MIL attention model with DROPOUT
# ------------------------------------------------------------
class AttentionMIL(nn.Module):
    """
    Attention pooling over the instances of one bag.

    A two-layer scorer (Linear -> tanh -> Dropout -> Linear) assigns one score
    per instance; the scores are soft-maxed over the instance axis and used to
    form a weighted sum of the *input* features (not of the hidden features).
    """

    def __init__(self, in_dim, hidden_dim, dropout=0.25):
        super().__init__()
        self.a = nn.Linear(in_dim, hidden_dim)
        self.drop = nn.Dropout(dropout)
        self.b = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """
        x: tensor [N, in_dim] (one bag).
        Returns (pooled [1, in_dim], weights [N]).
        """
        h = torch.tanh(self.a(x))
        h = self.drop(h)
        scores = self.b(h)               # [N, 1]
        weights = torch.softmax(scores, dim=0)
        pooled = torch.sum(weights * x, dim=0, keepdim=True)
        # squeeze only the trailing axis: a plain squeeze() would turn the
        # weights of a single-instance bag into a 0-d scalar instead of [1]
        return pooled, weights.squeeze(-1)


class MILModel(nn.Module):
    """
    Bag-level binary classifier: attention pooling followed by an MLP head.

    emb_dim: feature dimension of every instance in a bag.
    att_dim: hidden width of the attention scorer.
    cls_dim: hidden width of the classification head.
    dropout: dropout rate shared by the attention scorer and the head.
    """

    def __init__(self, emb_dim, att_dim=128, cls_dim=64, dropout=0.3):
        super().__init__()
        self.attention = AttentionMIL(emb_dim, att_dim, dropout=dropout)
        # head consumes the pooled bag vector, which keeps the input width
        head_layers = [
            nn.Linear(emb_dim, cls_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(cls_dim, 1),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, bag):
        """Return (logit for the bag, attention weights from the pooling)."""
        bag_vector, attention_weights = self.attention(bag)
        return self.classifier(bag_vector), attention_weights


# ------------------------------------------------------------
# 2. Multi-sample Dataset + Add Noise
# ------------------------------------------------------------
class MultiSampleRandomBags(Dataset):
    """
    Dataset that exposes every patient bag K times, each time as a fresh
    random sub-sample perturbed with Gaussian noise.

    bags: sequence of array-likes [N_i, d], one per patient.
    labels: array-like with one label per patient (stored as float32).
    n_samples: bags with more rows than this are sub-sampled to n_samples
        rows, drawn with replacement; smaller bags are used in full.
    K: number of dataset items generated per patient.
    noise_std: standard deviation of the additive Gaussian noise.
    """

    def __init__(self, bags, labels, n_samples=1000, K=5, noise_std=0.02):
        self.n_samples = n_samples
        self.K = K
        self.noise_std = noise_std
        self.num_patients = len(bags)
        self.labels_list = torch.as_tensor(labels, dtype=torch.float32)
        self.bags = [torch.as_tensor(b, dtype=torch.float32) for b in bags]

    def __len__(self):
        # every patient contributes K items
        return self.K * self.num_patients

    def __getitem__(self, idx):
        # items [p*K, (p+1)*K) all belong to patient p
        patient_idx = idx // self.K
        full_bag = self.bags[patient_idx]

        n_cells = full_bag.shape[0]
        if n_cells > self.n_samples:
            picks = torch.randint(0, n_cells, (self.n_samples,))
            sampled = full_bag[picks]
        else:
            sampled = full_bag.clone()

        # jitter the (sub-sampled) bag; computed on the tensor's own device
        noisy = sampled + self.noise_std * torch.randn_like(sampled)

        return noisy, self.labels_list[patient_idx]

def collate_bags(batch):
    """
    Collate (bag, label) pairs without padding the variable-length bags.

    batch: list of (bag tensor [N_i, d], label) pairs.
    Returns (list of float32 bag tensors, float32 label tensor [len(batch)]).
    """
    bags, labels = zip(*batch)
    bags = [b.float() for b in bags]                  # ensures float32 tensors
    # as_tensor accepts tensors and plain numbers alike and, unlike
    # torch.tensor(existing_tensor), does not emit a copy-construct warning
    labels = torch.stack([torch.as_tensor(lab, dtype=torch.float32) for lab in labels])
    return bags, labels




# ------------------------------------------------------------
# 3. Train function with class-balanced BCE
# ------------------------------------------------------------
def train_mil(model, dataloader, epochs=20, lr=1e-3, device="cuda"):
    """
    Train a MIL model with Adam and a class-balanced BCE-with-logits loss.

    model: module whose forward(bag) returns (logits, attention).
    dataloader: yields (list of bag tensors, label tensor); its dataset must
        expose `labels_list`, from which the positive-class weight
        (#negatives / #positives) is derived.
    epochs, lr: number of passes over the loader and Adam learning rate.
    device: device on which the model and the bags are placed.

    Prints the mean batch loss after every epoch and returns the model
    (left in training mode).
    """
    # The bags are moved to `device` below, so the parameters must live there
    # too; doing it here keeps callers that forgot `.to(device)` working.
    model = model.to(device)

    labels_all = dataloader.dataset.labels_list
    pos = (labels_all == 1).sum()
    neg = (labels_all == 0).sum()
    pos_weight = (neg / (pos + 1e-6)).float().to(device)

    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # guard the per-epoch average against an empty loader
    n_batches = max(len(dataloader), 1)

    model.train()
    for epoch in range(epochs):
        loss_sum = 0

        for bags, labels in dataloader:
            labels = labels.float().to(device)

            optimizer.zero_grad()
            batch_loss = 0

            # bags have different sizes, so they are forwarded one by one and
            # their losses averaged before a single optimizer step
            for bag, label in zip(bags, labels):
                bag = bag.to(device)
                logits, _ = model(bag)
                loss = criterion(logits.squeeze(), label)
                batch_loss += loss

            batch_loss /= len(bags)
            batch_loss.backward()
            optimizer.step()
            loss_sum += batch_loss.item()

        print(f"Epoch {epoch+1}/{epochs} LOSS {loss_sum/n_batches:.4f}")

    return model




# ------------------------------------------------------------
# 4. Predict on full bag
# ------------------------------------------------------------
def predict_patient(model, bag, device="cuda"):
    """
    Score one complete bag with a trained model.

    Switches the model to eval mode, runs a gradient-free forward pass on
    `device` and returns (probability as a Python float, attention weights
    as a NumPy array).
    """
    model.eval()
    with torch.no_grad():
        bag_logit, attention = model(bag.to(device))
        probability = torch.sigmoid(bag_logit).item()
    attention_np = attention.cpu().numpy()
    return probability, attention_np


# ------------------------------------------------------------
# 5. LOO Cross-Validation
# ------------------------------------------------------------
def loo_cv(bags, labels_dict, n_samples=1000, epochs=20, K=5, lr=1e-3,
           att_dim=128, cls_dim=64, device="cuda"):
    """
    Leave-one-patient-out cross-validation.

    For every patient a fresh MILModel is trained on all other patients
    (randomly sub-sampled, noise-augmented bags) and then evaluated on the
    complete bag of the held-out patient.

    bags: dict patient_id -> tensor [N_i, d].
    labels_dict: dict patient_id -> binary label.
    Returns a list with one dict per patient holding the keys
    "sample", "prob", "true" and "attention".
    """
    all_ids = list(bags.keys())
    fold_outputs = []

    for left_out in all_ids:
        print(f"\n=== LOO: Leaving out {left_out} ===")

        kept_ids = [pid for pid in all_ids if pid != left_out]

        fold_ds = MultiSampleRandomBags(
            [bags[pid] for pid in kept_ids],
            torch.tensor([labels_dict[pid] for pid in kept_ids]),
            n_samples=n_samples,
            K=K,
            noise_std=0.02,
        )
        fold_loader = DataLoader(fold_ds, batch_size=4, shuffle=True,
                                 collate_fn=collate_bags)

        # feature width is read from the first bag
        emb_dim = next(iter(bags.values())).shape[1]
        net = MILModel(emb_dim, att_dim=att_dim, cls_dim=cls_dim).to(device)
        net = train_mil(net, fold_loader, epochs=epochs, lr=lr, device=device)

        # score the held-out patient on its full (un-sampled) bag
        prob, att = predict_patient(net, bags[left_out], device=device)
        fold_outputs.append(dict(sample=left_out,
                                 prob=prob,
                                 true=labels_dict[left_out],
                                 attention=att))

    return fold_outputs


# ------------------------------------------------------------
# 6. GROUP K-FOLD Cross-Validation
# ------------------------------------------------------------
def group_kfold_cv(bags, labels_dict, groups, n_splits=5,
                   n_samples=1000, epochs=20, K=5,
                   lr=1e-3, att_dim=128, cls_dim=64,
                   device="cuda"):

    """
    Group K-fold cross-validation: patients sharing a group identifier are
    never split between the training and the test side of a fold.

    bags: dict patient_id -> tensor [N_i, d]
    labels_dict: dict patient_id -> binary label
    groups: array of group identifiers, length = #patients, in the same
            order as bags.keys()
    Returns a list with one dict per test patient holding the keys
    "sample", "prob", "true" and "attention".
    """

    patient_ids = list(bags.keys())
    X_dummy = np.zeros(len(patient_ids))    # sklearn requires X, doesn't matter
    y_dummy = np.zeros(len(patient_ids))

    gkf = GroupKFold(n_splits=n_splits)

    results = []

    for train_idx, test_idx in gkf.split(X_dummy, y_dummy, groups=groups):
        train_ids = [patient_ids[i] for i in train_idx]
        test_ids  = [patient_ids[i] for i in test_idx]

        print("\n=== Group Fold ===")
        print("Train:", train_ids)
        print("Test :", test_ids)

        train_bags = [bags[p] for p in train_ids]
        train_labels = torch.tensor([labels_dict[p] for p in train_ids])

        ds = MultiSampleRandomBags(train_bags, train_labels,
                                   n_samples=n_samples,
                                   K=K,
                                   noise_std=0.02)

        loader = DataLoader(ds, batch_size=4, shuffle=True, collate_fn=collate_bags)

        emb_dim = next(iter(bags.values())).shape[1]
        model = MILModel(emb_dim, att_dim=att_dim, cls_dim=cls_dim)
        # Same as in loo_cv: without this the weights stay on the CPU while
        # the bags are moved to `device`, which fails on a GPU.
        model = model.to(device)
        model = train_mil(model, loader, epochs=epochs, lr=lr, device=device)

        # test fold: every held-out patient is scored on its full bag
        for pid in test_ids:
            prob, att = predict_patient(model, bags[pid], device=device)
            results.append({"sample": pid,
                            "prob": prob,
                            "true": labels_dict[pid],
                            "attention": att})

    return results


# Load data

In [None]:
def group_small_clusters(
    df: "pd.DataFrame",
    cluster_col: str,
    min_count: int = 1000,
    new_label: str = "small_clusters",
    output_col: str = None
) -> "pd.Series":
    """
    Groups small clusters of a DataFrame column into a single label.

    Parameters:
        df (pd.DataFrame): Input DataFrame containing cluster labels.
        cluster_col (str): Name of the column containing cluster labels (e.g., 'leiden').
        min_count (int): Minimum number of entries a cluster must have to avoid grouping.
        new_label (str): Label to assign to small clusters.
        output_col (str or None): Name given to the returned Series.
                                  If None, defaults to '{cluster_col}_grouped'.

    Returns:
        pd.Series: The cluster labels as strings, indexed like `df`, with every
        cluster smaller than `min_count` replaced by `new_label`. `df` itself
        is left untouched.

    Raises:
        ValueError: If `cluster_col` is not a column of `df`.
    """
    # NOTE: the annotations are quoted on purpose, so this definition does not
    # need `pd` to be imported before it is executed.
    if cluster_col not in df.columns:
        raise ValueError(f"Column '{cluster_col}' not found in DataFrame.")

    output_col = output_col or f"{cluster_col}_grouped"
    labels = df[cluster_col]
    cluster_counts = labels.value_counts()
    small_clusters = cluster_counts[cluster_counts < min_count].index

    # work on the single column instead of copying the whole DataFrame
    grouped = labels.astype(str).mask(labels.isin(small_clusters), new_label)
    return grouped.rename(output_col)

In [None]:
import pandas as pd
import numpy as np
import pathlib as pl

In [None]:
clinical_info = pd.read_csv('../../../Broad_SpatialFoundation/VisiumHD-LUAD/clinical-info/full_clinical.csv', index_col=0)

In [None]:
base_dir = pl.Path('../../../Broad_SpatialFoundation/VisiumHD-LUAD-processed/')
sample_list = np.setdiff1d([f.stem for f in base_dir.iterdir()],['full_cohort','LIB-064888st1'])
sample_list = sample_list.astype(object)

In [None]:
sample_list

In [None]:
adata_obs = pd.read_parquet('../../../Broad_SpatialFoundation/notebooks/nsclc_adata_obs.parquet')

adata_obs['leiden_joint'] = group_small_clusters(
    adata_obs[['leiden']],
    cluster_col='leiden',
    min_count= 1000,
    new_label= "Other",
    output_col = None
)

malignant_niches = ['0','1','2','4','5','7','10','12','15','16',]

In [None]:
# Load the NicheFinder embedding table of every sample, keep the columns
# named '0'..'9', make the row labels unique across samples by appending
# '::<sample>' and keep only rows that are also present in adata_obs.
embeddings_df = {}
for sample in sample_list:
    emb = pd.read_parquet(base_dir / f'{sample}/embeddings/NicheFinder.parquet')
    emb.columns = emb.columns.astype(str)
    emb = emb[[f'{i}' for i in range(10)]]
    emb.index = emb.index + '::' + sample
    shared_rows = emb.index.intersection(adata_obs.index)
    embeddings_df[sample] = emb.loc[shared_rows]

In [None]:
# Split the grouped cluster labels ('leiden_joint') by sample:
# leiden[sample_id] is a one-column DataFrame of string labels.
df = adata_obs[['sample_id','leiden_joint']]
leiden = {}
for sample_id in df.sample_id.unique():
    leiden[sample_id] = df.loc[df.sample_id==sample_id][['leiden_joint']].astype(str)

# For every sample, bring the embeddings into the row order of the labels and
# then keep only the rows whose label is one of `malignant_niches`.
# NOTE(review): the first .loc presumably requires every labelled row of a
# sample to exist in its embedding table, and every key of embeddings_df to be
# a key of `leiden` - confirm both hold for the cohort.
sub_embeddings = {}
for sample in embeddings_df:
    sub_embeddings[sample] = embeddings_df[sample].loc[leiden[sample].index]
    sub_embeddings[sample] = sub_embeddings[sample].loc[leiden[sample]['leiden_joint'].isin(malignant_niches)]

In [None]:
for k,v in sub_embeddings.items():
    print(k)
    print(v.shape)

# Only with malignant niches

In [None]:
# One float32 tensor ("bag") of embeddings per sample.
bags = {}
for sample_id, df in sub_embeddings.items():
    bags[sample_id] = torch.tensor(df.values.astype(float), dtype=torch.float32)

# `labels_dict` is needed below but is only built in the training cell further
# down, so running the notebook top to bottom failed here with a NameError.
# Build it here with the same recipe as that cell (first dummy column of the
# 'pTNM T red' target, indexed by 'Library'); keep the two in sync.
_label_frame = pd.get_dummies(
    clinical_info[['Library', 'pTNM T red']].set_index('Library')
).astype(int)
labels_dict = _label_frame.iloc[:, 0].to_dict()

# enforce alignment (important!): bags and labels share the order of sample_ids
sample_ids = list(bags.keys())
bag_list = [bags[s] for s in sample_ids]
label_list = torch.tensor([labels_dict[s] for s in sample_ids], dtype=torch.float32)

print("Example bag shape:", bag_list[0].shape)

In [None]:
bags_scaled = patient_scale(bags)   # IMPORTANT

In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, balanced_accuracy_score

N_RUNS = 15
all_run_results = []
all_metrics = []

# NEW: store attention per sample per run
attention_store = {s: [] for s in bags_scaled.keys()}

tgt = 'pTNM T red'
targets = clinical_info[['Library',tgt]].set_index('Library')
targets = pd.get_dummies(targets).astype(int)
targets = targets.iloc[:,0].to_frame()
targets.columns = ["target"]
labels_dict = targets["target"].to_dict()

# Use the GPU when one is visible and fall back to the CPU otherwise, instead
# of failing inside loo_cv on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

for run in range(1, N_RUNS + 1):
    print("\n\n##############################")
    print(f"###      RUN {run} / {N_RUNS}      ###")
    print("##############################\n")

    # one complete leave-one-out pass (a new model per held-out sample)
    results = loo_cv(
        bags_scaled,
        labels_dict,
        n_samples=1000,
        epochs=15,
        K=10,
        lr=1e-3,
        att_dim=128,
        cls_dim=64,
        device=device
    )

    # extract predictions
    df = pd.DataFrame([
        {"sample": r["sample"], "true": r["true"], "prob": r["prob"]}
        for r in results
    ])
    df["run"] = run
    all_run_results.append(df)

    # store the attention vector of every held-out sample for this run
    for r in results:
        sample = r["sample"]
        att = r["attention"]  # vector of length (#cells in patient)
        attention_store[sample].append(att)

    # compute metrics for this run (probabilities thresholded at 0.5;
    # "acc" holds the balanced accuracy)
    auc = roc_auc_score(df["true"], df["prob"])
    acc = balanced_accuracy_score(df["true"], df["prob"] > 0.5)
    f1  = f1_score(df["true"], df["prob"] > 0.5)

    print(df)
    print(f"Run {run} — AUC: {auc:.3f}, BAC: {acc:.3f}, F1: {f1:.3f}")

    all_metrics.append({"run": run, "auc": auc, "acc": acc, "f1": f1})


# ------------------------------
# Combine per-run metrics
# ------------------------------
metrics_df = pd.DataFrame(all_metrics)
print("\n\n==========================================")
print("###  PER-RUN METRICS  ###")
print("==========================================")
print(metrics_df)

print("\n==========================================")
print("###  AGGREGATE PERFORMANCE  ###")
print("==========================================")

print(f"AUC: {metrics_df['auc'].mean():.3f} ± {metrics_df['auc'].std():.3f}")
print(f"BAC: {metrics_df['acc'].mean():.3f} ± {metrics_df['acc'].std():.3f}")
print(f"F1:  {metrics_df['f1'].mean():.3f} ± {metrics_df['f1'].std():.3f}")


# ------------------------------
# Averaged ensemble predictions
# ------------------------------
all_df = pd.concat(all_run_results)

ensemble_df = (
    all_df.groupby("sample")
          .agg(true=("true", "mean"),   # truth is identical across runs
               prob=("prob", "mean"))
          .reset_index()
)

ensemble_auc = roc_auc_score(ensemble_df["true"], ensemble_df["prob"])
ensemble_acc = balanced_accuracy_score(ensemble_df["true"], ensemble_df["prob"] > 0.5)
ensemble_f1  = f1_score(ensemble_df["true"], ensemble_df["prob"] > 0.5)

print("\n==========================================")
print("###  ENSEMBLE (AVERAGED-PROB) PERFORMANCE  ###")
print("==========================================")
print(ensemble_df)

print(f"Ensemble AUC: {ensemble_auc:.3f}")
print(f"Ensemble BAC: {ensemble_acc:.3f}")
print(f"Ensemble F1:  {ensemble_f1:.3f}")


# ------------------------------
# NEW: Compute averaged attention maps
# -------------------------------
avg_attention = {}

for sample, att_list in attention_store.items():
    # each entry is a vector of size (#cells in sample)
    A = np.stack(att_list, axis=0)     # shape = [num_runs, num_cells]
    avg_attention[sample] = A.mean(axis=0)

# Now avg_attention[sample] is the consensus attention per cell
print("\n==========================================")
print("###  STORED & AVERAGED ATTENTION MAPS ###")
print("==========================================")
for sample in avg_attention:
    print(sample, "attention shape:", avg_attention[sample].shape)


In [None]:
ensemble_auc = roc_auc_score(ensemble_df["true"], ensemble_df["prob"])
ensemble_acc = balanced_accuracy_score(ensemble_df["true"], ensemble_df["prob"] > 0.5)
ensemble_f1  = f1_score(ensemble_df["true"], ensemble_df["prob"] > 0.5)

print("\n==========================================")
print("###  ENSEMBLE (AVERAGED-PROB) PERFORMANCE  ###")
print("==========================================")
print(ensemble_df)

print(f"Ensemble AUC: {ensemble_auc:.3f}")
print(f"Ensemble BAC: {ensemble_acc:.3f}")
print(f"Ensemble F1:  {ensemble_f1:.3f}")

In [None]:
ensemble_df.to_csv('ensemble_pred_pTNMred.csv')

In [None]:
# Persist the averaged attention of every sample. The file name has to match
# the one the analysis section reads back
# ('{spl}_avg_attention_pTNM_T_red.parquet'); it used to be written without
# the '_pTNM_T_red' suffix, so the reload could not find these files.
for spl in avg_attention:
    pd.DataFrame(avg_attention[spl], index=sub_embeddings[spl].index, columns=['Avg attention']).to_parquet(f'{spl}_avg_attention_pTNM_T_red.parquet')

# Analyze results

In [None]:
import matplotlib
matplotlib.rcParams['svg.fonttype'] = 'none'

In [None]:
ensemble_df = pd.read_csv('ensemble_pred_pTNMred.csv', index_col=0)

avg_attention = {}
for spl in sample_list:
    avg_attention[spl]= pd.read_parquet(f'{spl}_avg_attention_pTNM_T_red.parquet')

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc, confusion_matrix
import numpy as np

# ================================
# Set Nature Genetics–like styling
# ================================

plt.rcParams.update({
    "font.size": 14,
    "axes.labelsize": 16,
    "axes.titlesize": 18,
    "axes.linewidth": 1.2,
    "xtick.labelsize": 13,
    "ytick.labelsize": 13,
    "legend.fontsize": 13,
    "figure.dpi": 120,
})

# Extract values
y_true = ensemble_df["true"].values
y_prob = ensemble_df["prob"].values
y_pred = (ensemble_df["prob"] > 0.5).astype(int)

# =======================================
# PLOT 1 — ROC CURVE (Nature Genetics look)
# =======================================
fpr, tpr, _ = roc_curve(y_true, y_prob)
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots(figsize=(4, 3))

ax.plot(fpr, tpr, color="#2c7fb8", lw=3, label=f"AUC = {roc_auc:.3f}")
ax.plot([0, 1], [0, 1], color="black", lw=1.2, linestyle="--")

ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title("ROC Curve")
ax.legend(loc="lower right", frameon=False)

sns.despine(trim=True)
plt.tight_layout()
fig.savefig('../../../SpatialFusion/results/figures_Fig6/roc_auc_abmil_pTNM_T_red.svg')
plt.show()

# ==========================================
# PLOT 2 — CONFUSION MATRIX (publication quality)
# ==========================================
cm = confusion_matrix(y_true, y_pred)
cm_norm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]

fig, ax = plt.subplots(1,1, figsize=(3,2))

# Raw counts
sns.heatmap(
    cm, annot=True, fmt="d", cmap="Blues",
    cbar=False, square=True, ax=ax,
    annot_kws={"size": 16}
)
ax.set_title("Confusion Matrix")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")

fig.savefig('../../../SpatialFusion/results/figures_Fig6/confusion_abmil_pTNM_T_red.svg')
plt.show()


In [None]:
import pandas as pd
import numpy as np

all_cells = []

for sample in avg_attention:
    att = avg_attention[sample].values.ravel()                        # vector (#cells,)
    idx = sub_embeddings[sample].index             # same order as bag                    # match your adata_obs index

    df = pd.DataFrame({
        "attention": att,
        "sample": sample,
        "cell_id": idx,
        "global_id": idx
    })
    df = df.set_index("global_id")
    # bring in metadata (cluster, subtype, coordinates,…)
    df = pd.concat([df, adata_obs.loc[df.index]],axis=1)

    all_cells.append(df)

all_cells_df = pd.concat(all_cells)
print("Combined shape:", all_cells_df.shape)


In [None]:
all_cells_df["att_norm"] = (
    all_cells_df.groupby("sample")["attention"]
    .transform(lambda x: x / x.sum())
)


In [None]:
QUANTILE = 0.95   # top 5% attention cells

top_cells = (
    all_cells_df.groupby("sample")
    .apply(lambda df: df[df.attention >= df.attention.quantile(QUANTILE)])
    .reset_index(drop=True)
)

print("Top cells shape:", top_cells.shape)


In [None]:
baseline = {}

for sample in avg_attention:
    df_all = all_cells_df[all_cells_df["sample"] == sample]
    
    baseline[sample] = {
        "leiden": df_all["leiden_joint"].value_counts(normalize=True),
        "subtype": df_all["refined_cellsubtypes"].value_counts(normalize=True),
    }


In [None]:
from scipy.stats import fisher_exact, chi2_contingency
import numpy as np
import pandas as pd
from statsmodels.stats.multitest import multipletests

# Per sample and per niche: is the niche over-represented among the
# top-attention cells compared with all cells of that sample?
results = []

TOP_Q = 0.95  # e.g., top 5% attention

for sample in avg_attention:
    df_all = all_cells_df[all_cells_df["sample"] == sample]

    # Identify top-attention cells
    cutoff = df_all["attention"].quantile(TOP_Q)
    df_top = df_all[df_all["attention"] >= cutoff]

    # Baseline & observed counts
    # (named n_top so the `top_cells` DataFrame built earlier is not overwritten)
    total_cells = len(df_all)
    n_top = len(df_top)

    # Process Leiden clusters
    for cluster in df_all["leiden_joint"].unique():

        k_obs = (df_top["leiden_joint"] == cluster).sum()            # in top
        k_exp = (df_all["leiden_joint"] == cluster).sum()            # baseline
        k_not_obs = n_top - k_obs                                    # not cluster in top
        k_not_exp = total_cells - k_exp                              # not cluster in sample

        # 2×2 table: rows = top / not top, columns = cluster / not cluster
        table = np.array([
            [k_obs,    k_not_obs],
            [k_exp-k_obs, k_not_exp-k_not_obs]
        ])

        # Fisher exact test; an invalid table is reported as "no evidence"
        # instead of aborting the loop, but nothing else is swallowed
        try:
            odds, p = fisher_exact(table)
        except ValueError:
            odds, p = np.nan, 1.0

        # Log2 fold-change enrichment (1e-6 pseudo-count avoids log2(0))
        frac_top = k_obs / n_top if n_top > 0 else 0
        frac_all = k_exp / total_cells
        log2fc = np.log2((frac_top + 1e-6) / (frac_all + 1e-6))

        # Store
        results.append({
            "sample": sample,
            "cluster": cluster,
            "k_obs": k_obs,
            "k_exp": k_exp,
            "top_frac": frac_top,
            "baseline_frac": frac_all,
            "log2FC": log2fc,
            "p_fisher": p,
            "odds_ratio": odds
        })

results_df = pd.DataFrame(results)

# Benjamini-Hochberg correction over all (sample, cluster) tests
results_df["q_fisher"] = multipletests(results_df["p_fisher"], method="fdr_bh")[1]

results_df.head()


In [None]:
results_df.sort_values('odds_ratio',ascending=False).head(20)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

def plot_cluster_enrichment(results_df, centroids_df, cluster_palette, savefig=None):
    """
    Volcano-style plot of niche enrichment among high-attention cells.

    results_df: per-sample results with columns "cluster", "log2FC", "q_fisher"
    centroids_df: aggregated cluster centers with columns "cluster",
                  "log2FC_median", "logp_median"
    cluster_palette: dict {cluster: (r,g,b)} from palettes["leiden_joint"],
                     keyed by the string form of the cluster label
    savefig: optional path; the figure is written there when given

    Returns (fig, ax). The caller's results_df is not modified.
    """

    # ------------------------------
    # Safely compute -log10(q)
    # Replace q=0 with the minimum positive value (or, when no q is
    # positive, with the smallest positive float so the result stays finite)
    # ------------------------------
    q = results_df["q_fisher"]
    positive_q = q[q > 0]
    q_floor = positive_q.min() if len(positive_q) else np.finfo(float).tiny
    results_df = results_df.assign(logp = -np.log10(q.replace(0, q_floor)))

    # ------------------------------
    # Create figure
    # ------------------------------
    fig, ax = plt.subplots(figsize=(5,4))

    # Consistent ordering and palette
    unique_clusters = sorted(results_df["cluster"].unique())
    palette_used = {c: cluster_palette[str(c)] for c in unique_clusters}

    # ------------------------------
    # 1. Individual sample points
    # ------------------------------
    sns.scatterplot(
        ax=ax,
        data=results_df,
        x="log2FC",
        y="logp",
        hue="cluster",
        palette=palette_used,
        s=10,
        alpha=0.7,
        linewidth=0
    )

    # ------------------------------
    # 2. Centroids (bigger, diamond marker)
    # ------------------------------
    sns.scatterplot(
        ax=ax,
        data=centroids_df,
        x="log2FC_median",
        y="logp_median",
        hue="cluster",
        palette=palette_used,
        s=30,
        marker="D",
        edgecolor="black",
        linewidth=1.3,
        alpha=0.9,
        legend=False   # avoid duplicate legends
    )

    # ------------------------------
    # Reference lines: no enrichment (x = 0) and FDR = 0.1 (y = 1)
    # ------------------------------
    ax.axvline(0, color="black", linestyle="--", linewidth=1)
    ax.axhline(1, color="black", linestyle="--", linewidth=1)

    # ------------------------------
    # Labels & title
    # ------------------------------
    ax.set_xlabel("log2 Fold-Change (top-attention vs background)", fontsize=14)
    ax.set_ylabel("-log10(FDR)", fontsize=14)
    ax.set_title("Niche Enrichment in High-Attention Cells", fontsize=16)

    # ------------------------------
    # Legend outside, no frame
    # ------------------------------
    legend = ax.legend(
        title="Niche",
        bbox_to_anchor=(1.02, 1),
        loc="upper left",
        frameon=False
    )

    # Increase dot size in legend.
    # `legend_handles` is the modern attribute name; older matplotlib
    # releases only expose `legendHandles`.
    handles = getattr(legend, "legend_handles", None)
    if handles is None:
        handles = getattr(legend, "legendHandles", [])
    for h in handles:
        # some handle types have no set_sizes; leave those untouched
        if hasattr(h, "set_sizes"):
            h.set_sizes([70])

    plt.tight_layout()
    if savefig:
        fig.savefig(savefig)
    plt.show()

    return fig, ax


In [None]:
import json

with open("../../../Broad_SpatialFoundation/notebooks/palettes_NSCLC_Novartis.json", "r") as f:
    palettes = json.load(f)

In [None]:
# `centroids` is not defined in any visible cell of this notebook, so it is
# derived here: per-cluster medians of log2FC and -log10(FDR) across samples,
# with the column names plot_cluster_enrichment expects. q = 0 is replaced by
# the smallest positive q before taking the log, as in the plotting function.
_q = results_df["q_fisher"]
_q = _q.replace(0, _q[_q > 0].min())
centroids = (
    results_df.assign(logp=-np.log10(_q))
    .groupby("cluster")
    .agg(log2FC_median=("log2FC", "median"),
         logp_median=("logp", "median"))
    .reset_index()
)

plot_cluster_enrichment(
    results_df=results_df,
    centroids_df=centroids,
    cluster_palette=palettes["leiden_joint"],
    savefig='../../../SpatialFusion/results/figures_Fig6/volcano_attention_leiden.svg',
)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# ===============================
# Global Nature Genetics styling
# ===============================

plt.rcParams.update({
    "font.size": 14,
    "axes.labelsize": 14,
    "axes.titlesize": 14,
    "axes.linewidth": 1.2,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "legend.fontsize": 14,
    "figure.dpi": 200,
})

# A nicer red colormap
from matplotlib.colors import LinearSegmentedColormap
NG_REDS = LinearSegmentedColormap.from_list(
    "NG_Reds",
    ["#fee5d9", "#fcae91", "#fb6a4a", "#de2d26", "#a50f15"]
)

def plot_attention_map_NG(sample, savefig=None):
    """
    Draw the spatial attention map of one sample.

    Reads the rows of the global `all_cells_df` whose "sample" equals
    `sample` and scatters them at their "X"/"Y" positions, coloured by the
    "attention" column with the NG_REDS colormap.

    sample: sample identifier as stored in all_cells_df["sample"].
    savefig: optional output path; when given, the figure is saved there
             with a tight bounding box.

    Returns nothing; the figure is shown with plt.show().
    """
    df = all_cells_df[all_cells_df["sample"] == sample]

    fig, ax = plt.subplots(figsize=(3.1, 3))

    # Scatter: Nature style = slightly larger points, calm alpha
    # (points are rasterized, which keeps vector output small)
    sc = ax.scatter(
        df["X"], df["Y"],
        c=df["attention"],
        cmap=NG_REDS,
        s=4,
        alpha=0.85,
        edgecolors="none",
        rasterized=True,
    )

    # Title
    ax.set_title(f"Attention map — {sample}", pad=12)

    # Reverse Y for Visium
    ax.invert_yaxis()
    ax.set_aspect("equal", adjustable="box")

    # Remove spines
    sns.despine(ax=ax)

    # Remove axis ticks and labels for a clean NG look
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("")
    ax.set_ylabel("")

    # Add a slim colorbar
    cbar = fig.colorbar(sc, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label("Attention", rotation=270, labelpad=15)
    cbar.outline.set_linewidth(0.8)

    plt.tight_layout()
    if savefig is not None:
        fig.savefig(savefig, bbox_inches='tight')
    plt.show()


# ===============================
# Plot all samples
# ===============================
for spl in all_cells_df['sample'].unique():
    print(spl)
    plot_attention_map_NG(spl, savefig=f'../../../SpatialFusion/results/figures_Fig6/{spl}_attn_map.svg')


# Dissociated

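Here we train the same ABMIL model on gene-expression features (per-sample PCA of the normalized, log-transformed counts) to compare against the niche-embedding results above.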

In [None]:
import scanpy as sc

In [None]:
base_dir = pl.Path('../../../Broad_SpatialFoundation/VisiumHD-LUAD-processed/')
sample_list = np.setdiff1d([f.stem for f in base_dir.iterdir()],['full_cohort','LIB-064888st1'])
sample_list

In [None]:
adatas = {}
for sample in tqdm(sample_list):
    adata = sc.read_h5ad(base_dir / sample / 'adata.h5ad')
    adata.obs_names = adata.obs_names + '::' + sample 
    adata.obs['sample_id'] = sample
    
    common_idx = adata.obs_names.intersection(embeddings_df[sample].index)
    adata = adata[common_idx].copy()
    adata = adata[adata.obs.celltypes != 'Noise'].copy()
    adatas[sample] = adata

In [None]:
# Gene-expression features per sample: library-size normalization to 10,000
# counts, log1p, then PCA; the PCA coordinates become the instance features.
# NOTE(review): the PCA is fitted separately on every sample, so component k
# of one sample presumably does not correspond to component k of another,
# while the MIL model is trained across samples on these columns. Consider a
# single PCA fitted on the pooled cells - TODO confirm intended behaviour.
gex_embeddings = {}
for sample in tqdm(sample_list):
    sc.pp.normalize_total(adatas[sample], target_sum=10000)
    sc.pp.log1p(adatas[sample])
    sc.tl.pca(adatas[sample])
    gex_embeddings[sample] = adatas[sample].obsm['X_pca'].copy()

In [None]:
bags = {}

for sample_id, df in gex_embeddings.items():
    bags[sample_id] = torch.tensor(df.astype(float), dtype=torch.float32)


# enforce alignment (important!)
sample_ids = list(bags.keys())
bag_list = [bags[s] for s in sample_ids]
label_list = torch.tensor([labels_dict[s] for s in sample_ids], dtype=torch.float32)

print("Example bag shape:", bag_list[0].shape)

In [None]:
bags_scaled = patient_scale(bags)   # IMPORTANT

In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, balanced_accuracy_score

N_RUNS = 15
all_run_results = []
all_metrics = []

# NEW: store attention per sample per run
attention_store = {s: [] for s in bags_scaled.keys()}

tgt = 'pTNM T red'
targets = clinical_info[['Library',tgt]].set_index('Library')
targets = pd.get_dummies(targets).astype(int)
targets = targets.iloc[:,0].to_frame()
targets.columns = ["target"]
labels_dict = targets["target"].to_dict()

# Use the GPU when one is visible and fall back to the CPU otherwise, instead
# of failing inside loo_cv on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

for run in range(1, N_RUNS + 1):
    print("\n\n##############################")
    print(f"###      RUN {run} / {N_RUNS}      ###")
    print("##############################\n")

    # one complete leave-one-out pass (a new model per held-out sample)
    results = loo_cv(
        bags_scaled,
        labels_dict,
        n_samples=1000,
        epochs=15,
        K=10,
        lr=1e-3,
        att_dim=128,
        cls_dim=64,
        device=device
    )

    # extract predictions
    df = pd.DataFrame([
        {"sample": r["sample"], "true": r["true"], "prob": r["prob"]}
        for r in results
    ])
    df["run"] = run
    all_run_results.append(df)

    # store the attention vector of every held-out sample for this run
    for r in results:
        sample = r["sample"]
        att = r["attention"]  # vector of length (#cells in patient)
        attention_store[sample].append(att)

    # compute metrics for this run (probabilities thresholded at 0.5;
    # "acc" holds the balanced accuracy)
    auc = roc_auc_score(df["true"], df["prob"])
    acc = balanced_accuracy_score(df["true"], df["prob"] > 0.5)
    f1  = f1_score(df["true"], df["prob"] > 0.5)

    print(df)
    print(f"Run {run} — AUC: {auc:.3f}, BAC: {acc:.3f}, F1: {f1:.3f}")

    all_metrics.append({"run": run, "auc": auc, "acc": acc, "f1": f1})


# ------------------------------
# Combine per-run metrics
# ------------------------------
metrics_df = pd.DataFrame(all_metrics)
print("\n\n==========================================")
print("###  PER-RUN METRICS  ###")
print("==========================================")
print(metrics_df)

print("\n==========================================")
print("###  AGGREGATE PERFORMANCE  ###")
print("==========================================")

print(f"AUC: {metrics_df['auc'].mean():.3f} ± {metrics_df['auc'].std():.3f}")
print(f"BAC: {metrics_df['acc'].mean():.3f} ± {metrics_df['acc'].std():.3f}")
print(f"F1:  {metrics_df['f1'].mean():.3f} ± {metrics_df['f1'].std():.3f}")


# ------------------------------
# Averaged ensemble predictions
# ------------------------------
all_df = pd.concat(all_run_results)

ensemble_df = (
    all_df.groupby("sample")
          .agg(true=("true", "mean"),   # truth is identical across runs
               prob=("prob", "mean"))
          .reset_index()
)

ensemble_auc = roc_auc_score(ensemble_df["true"], ensemble_df["prob"])
ensemble_acc = balanced_accuracy_score(ensemble_df["true"], ensemble_df["prob"] > 0.5)
ensemble_f1  = f1_score(ensemble_df["true"], ensemble_df["prob"] > 0.5)

print("\n==========================================")
print("###  ENSEMBLE (AVERAGED-PROB) PERFORMANCE  ###")
print("==========================================")
print(ensemble_df)

print(f"Ensemble AUC: {ensemble_auc:.3f}")
print(f"Ensemble BAC: {ensemble_acc:.3f}")
print(f"Ensemble F1:  {ensemble_f1:.3f}")


# ------------------------------
# NEW: Compute averaged attention maps
# -------------------------------
avg_attention = {}

for sample, att_list in attention_store.items():
    # each entry is a vector of size (#cells in sample)
    A = np.stack(att_list, axis=0)     # shape = [num_runs, num_cells]
    avg_attention[sample] = A.mean(axis=0)

# Now avg_attention[sample] is the consensus attention per cell
print("\n==========================================")
print("###  STORED & AVERAGED ATTENTION MAPS ###")
print("==========================================")
for sample in avg_attention:
    print(sample, "attention shape:", avg_attention[sample].shape)


In [None]:
ensemble_auc = roc_auc_score(ensemble_df["true"], ensemble_df["prob"])
ensemble_acc = balanced_accuracy_score(ensemble_df["true"], ensemble_df["prob"] > 0.5)
ensemble_f1  = f1_score(ensemble_df["true"], ensemble_df["prob"] > 0.5)

print("\n==========================================")
print("###  ENSEMBLE (AVERAGED-PROB) PERFORMANCE  ###")
print("==========================================")
print(ensemble_df)

print(f"Ensemble AUC: {ensemble_auc:.3f}")
print(f"Ensemble BAC: {ensemble_acc:.3f}")
print(f"Ensemble F1:  {ensemble_f1:.3f}")

In [None]:
ensemble_df.to_csv('ensemble_pred_GEX.csv')

In [None]:
for spl in avg_attention:
    pd.DataFrame(avg_attention[spl], index=adatas[spl].obs_names, columns=['Avg attention']).to_parquet(f'{spl}_avg_attention_GEX.parquet')