# Step 4: Processing GASTON Neural Network Output

## What is this step doing?

After Step 3, we have **30 independently trained bottleneck autoencoders** per slice, each producing a scalar **isodepth** $d_i$ for every spot $i$. This step turns those raw model files into interpretable spatial domain labels and evaluates their quality against expert-annotated ground truth.

There are three sub-tasks here:

1. **Model selection** — pick the best of the 30 restarts
2. **Isodepth extraction + domain segmentation** — convert the scalar isodepth into discrete layer labels using breakpoint detection (dynamic programming)
3. **Evaluation** — compare predicted domains to the DLPFC ground-truth annotations using ARI and NMI

---

## Sub-task 1: Model Selection

**Strategy A — pick by reconstruction loss** (`process_NN_output.process_files`)  
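The GASTON library reads the `min_loss.txt` file of each of the 30 restarts and returns the model with the lowest recorded training loss (reconstruction MSE). This is the standard approach: a lower reconstruction loss means a better fit to the GLM-PCA embedding, although it does not guarantee better agreement with the expert annotations.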
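The selection rule itself is simple. The sketch below reimplements it under one assumption: that each restart lives in its own `rep{i}/` folder containing a `min_loss.txt` with a single float, which is the layout the oracle loop later in this notebook reads. The helper `select_rep_by_loss` is hypothetical and does not reproduce `process_files` internals; it only returns the winning rep name and its loss.

```python
import os
import tempfile
from pathlib import Path


def select_rep_by_loss(slice_nn_dir):
    """Return (rep_name, loss) for the rep whose min_loss.txt value is smallest.

    Assumes the layout slice_nn_dir/rep*/min_loss.txt (one float per file).
    """
    losses = {}
    for loss_file in Path(slice_nn_dir).glob('rep*/min_loss.txt'):
        losses[loss_file.parent.name] = float(loss_file.read_text().strip())
    if not losses:
        raise FileNotFoundError(f'no rep*/min_loss.txt found under {slice_nn_dir}')
    best = min(losses, key=losses.get)
    return best, losses[best]


# Toy layout: three reps with made-up losses
with tempfile.TemporaryDirectory() as d:
    for name, loss in [('rep0', 0.52), ('rep1', 0.31), ('rep2', 0.47)]:
        os.makedirs(os.path.join(d, name))
        Path(d, name, 'min_loss.txt').write_text(f'{loss}\n')
    best_rep, best_loss = select_rep_by_loss(d)

print(best_rep, best_loss)   # rep1 0.31
```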

**Strategy B — pick by ARI against ground truth** (used later in this notebook)  
Since we have expert annotations, we can also iterate over all 30 reps, extract domains for each, and pick the rep with the highest ARI. This is 'oracle' selection: it uses labels we would not have in a real unsupervised setting, so its score is an optimistic upper bound on what any single restart can achieve. It is useful for ablations but should not be reported as an unsupervised result.
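To quantify how far loss-based selection falls from the oracle, a small helper like the one below can summarize the `rep_name -> {'ari', 'nmi', 'loss', ...}` dictionary that `test_all_reps_for_slice` returns later in this notebook. The name `compare_selection_rules` is hypothetical, and the toy numbers are made up. It reports both picks, the ARI given up by choosing on loss, and the Spearman rank correlation between loss and ARI across restarts; a strongly negative value would mean low loss tends to go with high ARI.

```python
import numpy as np
from scipy.stats import spearmanr


def compare_selection_rules(rep_results):
    """Compare best-by-loss (Strategy A) with best-by-ARI (Strategy B, oracle).

    rep_results: dict rep_name -> {'loss': float, 'ari': float, ...}
    """
    names = list(rep_results)
    loss = np.array([rep_results[n]['loss'] for n in names])
    ari = np.array([rep_results[n]['ari'] for n in names])
    by_loss = names[int(np.argmin(loss))]
    by_ari = names[int(np.argmax(ari))]
    rho, _ = spearmanr(loss, ari)   # rank correlation across restarts
    return {
        'best_by_loss': by_loss,
        'best_by_ari': by_ari,
        'ari_gap': float(ari.max() - rep_results[by_loss]['ari']),
        'spearman_loss_vs_ari': float(rho),
    }


toy = {
    'rep0': {'loss': 0.1, 'ari': 0.5},
    'rep1': {'loss': 0.2, 'ari': 0.6},
    'rep2': {'loss': 0.3, 'ari': 0.4},
}
summary = compare_selection_rules(toy)
print(summary)
```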

---

## Sub-task 2: Isodepth → Domain Labels

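The decoder $h_\psi : \mathbb{R} \to \mathbb{R}^{14}$ maps each isodepth value back to the 14-dimensional GLM-PCA embedding. Because it is a ReLU MLP with a 1D input, it traces a **piecewise-linear curve** in embedding space. GASTON's premise is that the points where this curve changes slope correspond to **transitions between tissue layers**.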

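`dp_related.get_isodepth_labels(model, A, S, num_layers)` runs **dynamic programming** to find the optimal `num_layers - 1` breakpoints along the isodepth axis. Every spot between two consecutive breakpoints gets the same label, so each domain is a contiguous band of isodepth, and the breakpoints are chosen to minimize the total within-segment fit error. This is closely related to exact 1D k-means, which can also be solved optimally with dynamic programming.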

The output is:
- `gaston_isodepth` — the scalar $d_i$ for each spot (a continuous value)
- `gaston_labels` — integer domain label 0..K-1 for each spot (discrete, from breakpoint detection)
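To make the idea concrete, here is a simplified sketch of optimal 1D segmentation by dynamic programming. It is **not** GASTON's implementation. It assumes a piecewise-constant model, where each segment's cost is the sum of squared deviations of its rows of `A` from the segment mean (as in 1D k-means). GASTON's own cost function, and any bucketing of isodepth values, may differ. The function name `segment_isodepth` is hypothetical.

```python
import numpy as np


def segment_isodepth(isodepth, A, num_layers):
    """Split spots into num_layers contiguous isodepth intervals, minimising the
    total within-segment sum of squared deviations of the rows of A."""
    order = np.argsort(isodepth, kind='stable')
    X = np.asarray(A, dtype=float)[order]
    n = X.shape[0]
    if not 1 <= num_layers <= n:
        raise ValueError('need 1 <= num_layers <= number of spots')
    # Prefix sums give the SSE of any contiguous run X[i:j] in O(dims)
    P = np.vstack([np.zeros(X.shape[1]), np.cumsum(X, axis=0)])
    Q = np.concatenate([[0.0], np.cumsum((X ** 2).sum(axis=1))])

    def sse(i, j):
        s = P[j] - P[i]
        return Q[j] - Q[i] - s @ s / (j - i)

    cost = np.full((num_layers + 1, n + 1), np.inf)
    back = np.zeros((num_layers + 1, n + 1), dtype=int)
    cost[0, 0] = 0.0
    for k in range(1, num_layers + 1):
        for j in range(k, n + 1):
            # last segment is X[i:j]; the first k-1 segments cover X[:i]
            for i in range(k - 1, j):
                c = cost[k - 1, i] + sse(i, j)
                if c < cost[k, j]:
                    cost[k, j], back[k, j] = c, i
    # Walk the back-pointers to label each sorted spot with its segment index
    labels_sorted = np.empty(n, dtype=int)
    j = n
    for k in range(num_layers, 0, -1):
        i = back[k, j]
        labels_sorted[i:j] = k - 1
        j = i
    labels = np.empty(n, dtype=int)
    labels[order] = labels_sorted   # undo the sort
    return labels


# Toy data: three flat "layers" along the isodepth axis, spots in random order
rng = np.random.default_rng(1)
iso_toy = np.linspace(0.0, 1.0, 30)
A_toy = np.repeat(np.array([[0.0, 0.0], [5.0, 5.0], [10.0, 10.0]]), 10, axis=0)
perm = rng.permutation(30)
labels_toy = segment_isodepth(iso_toy[perm], A_toy[perm], num_layers=3)
```

The triple loop is O(K n²) and is meant for illustration only; for thousands of spots you would vectorize the inner loop or bucket the isodepth values first. Labels here increase with isodepth (0 = lowest band).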

**Design choice:** `num_layers = 7` is hardcoded here because we know the DLPFC has 6 cortical layers + white matter. In a real unsupervised setting, you would use `model_selection.plot_ll_curve` to estimate the number of domains from a likelihood curve (like an elbow/BIC plot for k-means).

---

## Sub-task 3: Evaluation — ARI and NMI

**ARI (Adjusted Rand Index)**: Measures agreement between two clusterings, corrected for chance. Range: [-0.5, 1] (the bounds documented by scikit-learn), higher is better. ARI = 1 means perfect agreement; random labelings give ARI ≈ 0.

**NMI (Normalized Mutual Information)**: Measures shared information between two clusterings, normalized to [0, 1]. Higher is better. NMI = 1 means perfect agreement.

Both metrics are **permutation-invariant** — they don't care which integer label maps to which layer, only whether the partitioning agrees.
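A quick check of both properties with scikit-learn's `adjusted_rand_score` and `normalized_mutual_info_score`. The helper `score_domains` is hypothetical. It assumes the notebook's convention that spots without a valid annotation are encoded as `-1` (as `gt_label_map.get(l, -1)` does below) and drops them before scoring. Whether to drop or keep them is a reporting choice, but it changes the numbers, so it is worth stating.

```python
import numpy as np
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score


def score_domains(gt_numeric, pred, unannotated=-1):
    """ARI and NMI computed only over spots whose ground-truth label != unannotated."""
    gt_numeric, pred = np.asarray(gt_numeric), np.asarray(pred)
    keep = gt_numeric != unannotated
    return (adjusted_rand_score(gt_numeric[keep], pred[keep]),
            normalized_mutual_info_score(gt_numeric[keep], pred[keep]))


# Permutation invariance: same partition, different integer names
gt = np.array([0, 0, 1, 1, 2, 2])
pred = np.array([2, 2, 0, 0, 1, 1])
ari_perm, nmi_perm = score_domains(gt, pred)

# Unannotated spots (-1) change the score if they are kept as their own "cluster"
gt_na = np.array([0, 0, 1, 1, -1, -1])
pred_na = np.array([1, 1, 0, 0, 0, 1])
ari_masked, _ = score_domains(gt_na, pred_na)
ari_unmasked = adjusted_rand_score(gt_na, pred_na)
print(ari_perm, nmi_perm, ari_masked, ari_unmasked)
```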

---

## Could we replace the whole GASTON pipeline with PCA / NMF / scVI?

Yes — and running these as baselines is exactly the kind of ablation that makes a strong project. The table below compares GASTON against what you would get if you used standard ML tools from EECS 545 for spatial domain detection:

| Method | Embedding (replaces Step 2) | Clustering (replaces Steps 3+4) | Key limitation vs. GASTON |
|---|---|---|---|
| **PCA + Leiden** | PCA on log-normalized counts | k-NN graph → Leiden clustering | PCA assumes Gaussian noise; Leiden ignores spatial coordinates |
| **NMF + k-means** | NMF gives non-negative topic scores per spot | k-means on topic vectors | No spatial constraint; topics may not align with layer ordering |
| **scVI + Leiden** | VAE with Negative Binomial likelihood | Leiden on latent embedding | scVI ignores (x,y) coordinates entirely |
| **STAGATE** | Graph attention autoencoder on spatial neighbors | Leiden on latent | No continuous isodepth; layer ordering not explicitly learned |
| **GASTON (ours)** | GLM-PCA (Step 2) | Bottleneck autoencoder on (x,y) → isodepth → DP segmentation | Assumes 1D dominant structure |

**What makes GASTON different** from all of the above is the **bottleneck constraint**: it forces the model to explain all expression variation with a *single* spatial axis anchored to (x,y) coordinates. This is a strong inductive bias. It works well when the tissue has a dominant 1D structure (like cortical layers running top-to-bottom), and may struggle on tissues with more complex 2D spatial organization.

Any of the alternatives above is a valid baseline. Comparing GASTON's ARI against PCA+Leiden's ARI is a clean, one-script experiment to add to the paper.
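A minimal sketch of the simplest row in the table, assuming a raw spots × genes count matrix. To stay dependency-light it swaps Leiden for scikit-learn's `KMeans` with a fixed number of domains, so it is a PCA + k-means baseline rather than PCA + Leiden. The normalization (library-size scaling to the median total count, then `log1p`), `n_pcs=10`, the helper name `pca_kmeans_baseline`, and the synthetic data are all illustrative assumptions. On a real slice you would pass the densified `adata.X` and score against the same `gt_labels_numeric` used in the evaluation cells.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics import adjusted_rand_score


def pca_kmeans_baseline(counts, n_domains, n_pcs=10, seed=0):
    """Log-normalise counts, project onto the top PCs, then k-means into n_domains."""
    counts = np.asarray(counts, dtype=float)
    lib = np.maximum(counts.sum(axis=1, keepdims=True), 1.0)   # avoid divide-by-zero
    X = np.log1p(counts / lib * np.median(lib))
    Z = PCA(n_components=n_pcs, random_state=seed).fit_transform(X)
    return KMeans(n_clusters=n_domains, n_init=10, random_state=seed).fit_predict(Z)


# Synthetic check: 3 "layers" x 50 spots, each layer over-expresses its own 20 genes
rng = np.random.default_rng(0)
truth = np.repeat(np.arange(3), 50)
means = np.ones((3, 100))
for k in range(3):
    means[k, 20 * k:20 * (k + 1)] = 20.0
counts = rng.poisson(means[truth])
pred = pca_kmeans_baseline(counts, n_domains=3)
ari = adjusted_rand_score(truth, pred)
print(f'ARI = {ari:.3f}')
```

Nothing in this baseline uses the (x, y) coordinates; that missing spatial term is exactly what the comparison with GASTON probes.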

---

## Why this step matters for the C-GASTON project

**This step establishes the baseline ARI that C-GASTON has to beat.**

The whole point of C-GASTON is to ask: *does adding H&E image features to the encoder improve spatial domain detection?*

The pipeline comparison:

```
GASTON (baseline, Steps 1-4):
  Gene counts → GLM-PCA → A (N×14)
  Spatial coords (x,y) ──────────→ Encoder φ_θ: R² → R  →  isodepth d_i
                                    Decoder h_ψ: R  → R^14 →  reconstructed A
  → domain labels (DP)  →  ARI ← THIS NOTEBOOK measures this number

C-GASTON (proposed):
  Gene counts → GLM-PCA → A (N×14)
  Spatial coords (x,y)  ─┐
  H&E patch features (f) ─┴→ Encoder φ_θ: R^(2+d) → R  →  isodepth d_i
                               Decoder h_ψ: R        → R^14 →  reconstructed A
  → domain labels (DP)  →  ARI ← same evaluation; expect this number to go up
```

The **only change** between GASTON and C-GASTON is the encoder input: `(x, y)` → `(x, y, f)` where `f` is a feature vector from the H&E image patch at that spot's location.

**Why would H&E features help?**  
H&E contains morphological information that gene expression does not measure directly: cell density, cell size, nuclear staining intensity, tissue architecture. Different DLPFC layers look visually distinct (e.g., Layer 4 has densely packed granule cells; white matter is pale with axon bundles). If the encoder can leverage this signal, it should learn an isodepth that more cleanly separates layers, especially at layer boundaries where gene expression transitions are gradual.

**How to extract H&E patch features (the vision branch):**
- For each spot $i$, crop a patch centered at its `(imagerow, imagecol)` pixel location in the high-resolution H&E image
- Feed each patch through a pretrained vision encoder: ResNet-50, ViT-B/16, or a pathology-specific model (UNI, CONCH, PLIP)
- Output: feature vector $\mathbf{f}_i \in \mathbb{R}^d$ (e.g., $d = 2048$ for ResNet-50, $d = 768$ for ViT-B/16)
- Concatenate: encoder input = $[\,x_i,\; y_i,\; \mathbf{f}_i\,] \in \mathbb{R}^{2+d}$
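The crop-and-concatenate steps above can be sketched in NumPy as follows. To keep it runnable without model weights, `patch_features` is a **hypothetical stand-in** (per-channel mean and standard deviation, so $d = 6$) for the pretrained encoder you would actually use. The helper names, the 32-pixel patch size, zero-padding at the image border, and taking $(x, y)$ = (column, row) are all assumptions. In Visium data loaded with scanpy, `imagerow`/`imagecol` are typically full-resolution pixel coordinates, so if you crop from the downsampled high-resolution image, scale them by its scale factor first.

```python
import numpy as np


def crop_patch(image, row, col, size=64):
    """Crop a size x size patch centred on pixel (row, col); zero-pad at borders."""
    half = size // 2
    padded = np.pad(image, ((half, half), (half, half), (0, 0)))
    r, c = int(round(row)), int(round(col))
    # pixel (r, c) of the original image sits at (r + half, c + half) in `padded`
    return padded[r:r + size, c:c + size]


def patch_features(patch):
    """HYPOTHETICAL stand-in for a pretrained encoder: per-channel mean and std."""
    p = patch.astype(float) / 255.0
    return np.concatenate([p.mean(axis=(0, 1)), p.std(axis=(0, 1))])


def build_encoder_inputs(S, pixel_rc, image, size=64):
    """Return the (N, 2 + d) matrix [x_i, y_i, f_i] fed to the C-GASTON encoder."""
    feats = np.stack([patch_features(crop_patch(image, r, c, size))
                      for r, c in pixel_rc])
    return np.hstack([np.asarray(S, dtype=float), feats])


# Toy example: a 200x200 RGB "slide" whose left half is dark, right half bright
image = np.zeros((200, 200, 3), dtype=np.uint8)
image[:, 100:] = 230
pixel_rc = np.array([[50.0, 20.0], [150.0, 180.0]])   # (imagerow, imagecol)
S = pixel_rc[:, ::-1]                                  # (x, y) = (col, row)
X = build_encoder_inputs(S, pixel_rc, image, size=32)
print(X.shape)   # (2, 8): 2 coordinates + 6 stand-in features
```

Swapping in a real encoder only changes `patch_features` (and therefore $d$); the concatenation with $(x_i, y_i)$ stays the same. In practice you would also standardize the feature columns so they do not swamp the two coordinates.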

**Optional: contrastive regularization (InfoNCE)**  
Beyond simply concatenating features, C-GASTON can add a contrastive loss to explicitly align the isodepth with H&E similarity:

$$\mathcal{L} = \underbrace{\frac{1}{N}\sum_i \|\mathbf{a}_i - h_\psi(d_i)\|^2}_{\text{reconstruction (GASTON)}} + \lambda \cdot \underbrace{\mathcal{L}_{\text{InfoNCE}}(d_i,\, \mathbf{f}_i)}_{\text{align isodepth with H\&E}}$$

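In the standard formulation, each spot's (projected) isodepth and its own H&E patch embedding form a positive pair, and the other spots in the batch serve as negatives. The loss therefore rewards isodepths that are predictive of patch appearance, so spots whose patches look alike (and plausibly share a layer) are pulled toward similar isodepths while visually different spots are pushed apart. This is analogous to contrastive self-supervised learning from EECS 545, applied here to enforce consistency between the visual and genomic signals.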
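A minimal forward-pass sketch of a symmetric (CLIP-style) $\mathcal{L}_{\text{InfoNCE}}$, written in NumPy for readability. It assumes the scalar $d_i$ has already been lifted to an $m$-dimensional embedding by a hypothetical projection head (InfoNCE on a raw scalar is degenerate), and that $\mathbf{f}_i$ has been projected to the same $m$. The temperature of 0.1 is illustrative. For training, the same computation would be written in PyTorch so gradients reach $\phi_\theta$.

```python
import numpy as np


def _mean_diag_xent(logits):
    """Cross-entropy where row i's correct class is column i (stable log-softmax)."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_prob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_prob))


def info_nce(z_iso, z_img, temperature=0.1):
    """Symmetric InfoNCE for N paired embeddings (row i of each is a positive pair)."""
    z_iso = z_iso / np.linalg.norm(z_iso, axis=1, keepdims=True)
    z_img = z_img / np.linalg.norm(z_img, axis=1, keepdims=True)
    logits = z_iso @ z_img.T / temperature      # (N, N) cosine similarities / T
    return 0.5 * (_mean_diag_xent(logits) + _mean_diag_xent(logits.T))


rng = np.random.default_rng(0)
z_img = rng.normal(size=(8, 16))                 # stand-in projected H&E features
z_iso = z_img + 0.1 * rng.normal(size=(8, 16))   # stand-in projected isodepths, well aligned
aligned = info_nce(z_iso, z_img)
mismatched = info_nce(np.roll(z_iso, 1, axis=0), z_img)   # every pair broken
print(aligned, mismatched)
```

`np.roll` pairs each isodepth embedding with a neighbour's patch, so the loss rises sharply. If every spot had the same embedding, the loss would equal $\log N$, the chance level.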

**Step 4 is unchanged in C-GASTON** — once we have the trained model, the DP segmentation and ARI/NMI evaluation are identical. The only question is whether the ARI numbers go up.

In [None]:
# =============================================================================
# Step 4: Processing GASTON Neural Network Output
# =============================================================================
# EECS 545 Project: C-GASTON
# Reference: https://gaston-tutorial.readthedocs.io/

import os
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

from gaston import process_NN_output, dp_related, cluster_plotting
from gaston import model_selection, isodepth_scaling

plt.rcParams['figure.dpi'] = 150
plt.rcParams['savefig.dpi'] = 150

print("Libraries imported.")

In [None]:
# =============================================================================
# Paths
# =============================================================================

BASE_DIR = '/home/siruilf/A_new_dataset_for_gaston'

# Inputs from previous steps
GLMPCA_DIR = f'{BASE_DIR}/2.GLM_PC/glmpca_results'
NN_DIR     = f'{BASE_DIR}/3.Training_Gaston__NN/nn_results'

# Output directory
OUTPUT_DIR = f'{BASE_DIR}/4.Process_NN_Output/results'
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"GLM-PCA dir : {GLMPCA_DIR}")
print(f"NN dir      : {NN_DIR}")
print(f"Output dir  : {OUTPUT_DIR}")

In [None]:
# =============================================================================
# Slices to process
# =============================================================================

slices_to_process = ['151507', '151508', '151673', '151674']

# DLPFC cortical layer colors (for reference — gaston uses color_palette internally)
LAYER_COLORS = {
    'L1': '#E41A1C',
    'L2': '#377EB8',
    'L3': '#4DAF4A',
    'L4': '#984EA3',
    'L5': '#FF7F00',
    'L6': '#FFFF33',
    'WM': '#A65628',
}

print("Checking for trained NN results...")
for slice_id in slices_to_process:
    nn_path = f'{NN_DIR}/{slice_id}'
    exists = os.path.exists(nn_path)
    print(f"  {slice_id}: {'ready' if exists else 'MISSING — run Step 3 first'}")

In [None]:
# =============================================================================
# Processing function
# =============================================================================

def process_slice(slice_id, nn_dir, glmpca_dir, output_dir, num_layers=7):
    """
    Process the NN output for a single tissue slice.

    Pipeline
    --------
    1. Load best model  — process_NN_output.process_files picks the rep with
       the lowest reconstruction MSE across all 30 restarts.
    2. Model selection plot  — log-likelihood curve to verify that 7 domains
       is a reasonable choice (like an elbow plot for k-means).
    3. Isodepth + domain labels  — dp_related.get_isodepth_labels runs DP
       breakpoint detection on the 1D isodepth to segment into `num_layers`
       discrete domains.
    4. Visualize isodepth map  — continuous scalar field over (x, y).
    5. Visualize spatial domains  — discrete domain labels over (x, y).
    6. Save outputs  — isodepth.npy, labels.npy for downstream use.

    Parameters
    ----------
    slice_id   : str  e.g. '151507'
    nn_dir     : str  directory with trained model checkpoints
    glmpca_dir : str  directory with GLM-PCA outputs (not used here, kept for symmetry)
    output_dir : str  where to save results
    num_layers : int  number of domains (7 for DLPFC: L1-L6 + WM)
    """
    print(f"\n{'='*60}")
    print(f"Slice: {slice_id}")
    print(f"{'='*60}")

    slice_output_dir = f'{output_dir}/{slice_id}'
    os.makedirs(slice_output_dir, exist_ok=True)

    # --- Step 1: Load best model (selected by lowest MSE loss) ---
    print("\n[1/6] Loading best model (lowest MSE across 30 restarts)...")
    nn_slice_dir = f'{nn_dir}/{slice_id}'
    gaston_model, A, S = process_NN_output.process_files(nn_slice_dir)
    print(f"      A shape: {A.shape}  (N x K embeddings)")
    print(f"      S shape: {S.shape}  (N x 2 coordinates)")

    # --- Step 2: Model selection plot (log-likelihood curve) ---
    # This is analogous to a BIC/elbow plot for k-means: we try different
    # numbers of domains and look for where the curve bends.
    print("\n[2/6] Plotting log-likelihood curve (model selection)...")
    fig_path = f'{slice_output_dir}/model_selection.png'
    plt.figure()
    model_selection.plot_ll_curve(gaston_model, A, S, max_domain_num=10, start_from=2, num_buckets=100)
    plt.savefig(fig_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close()
    print(f"      Saved: {fig_path}")

    # --- Step 3: Isodepth + domain labels via DP breakpoint detection ---
    # get_isodepth_labels runs 1D dynamic programming to find the num_layers-1
    # breakpoints along the isodepth axis that minimize within-segment variance.
    print(f"\n[3/6] Computing isodepth + domain labels (num_layers={num_layers})...")
    gaston_isodepth, gaston_labels = dp_related.get_isodepth_labels(gaston_model, A, S, num_layers)
    print(f"      isodepth shape : {gaston_isodepth.shape}  (continuous scalar per spot)")
    print(f"      labels shape   : {gaston_labels.shape}   (discrete int per spot)")
    print(f"      unique labels  : {np.unique(gaston_labels)}")

    # --- Step 4: Isodepth map ---
    print("\n[4/6] Plotting isodepth map with streamlines...")
    fig_path = f'{slice_output_dir}/isodepth_map.png'
    cluster_plotting.plot_isodepth(gaston_isodepth, S, gaston_model,
                                   figsize=(7, 6), streamlines=True, cmap='Reds')
    plt.savefig(fig_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close()
    print(f"      Saved: {fig_path}")

    # --- Step 5: Spatial domain map ---
    print("\n[5/6] Plotting spatial domains...")
    fig_path = f'{slice_output_dir}/spatial_domains.png'
    cluster_plotting.plot_clusters(gaston_labels, S, figsize=(6, 6),
                                   color_palette=plt.cm.tab10, s=10, lgd=True)
    plt.savefig(fig_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close()
    print(f"      Saved: {fig_path}")

    # --- Step 6: Save outputs ---
    print("\n[6/6] Saving outputs...")
    np.save(f'{slice_output_dir}/gaston_isodepth.npy', gaston_isodepth)
    np.save(f'{slice_output_dir}/gaston_labels.npy', gaston_labels)
    print(f"      Saved to: {slice_output_dir}")

    return {
        'gaston_model': gaston_model,
        'A': A,
        'S': S,
        'isodepth': gaston_isodepth,
        'labels': gaston_labels
    }


print("Function defined.")

In [None]:
# =============================================================================
# Process all slices (using best-by-loss model selection)
# =============================================================================
# num_layers = 7: the DLPFC has 6 cortical layers (L1-L6) + white matter (WM)

num_layers = 7

results = {}

for slice_id in slices_to_process:
    results[slice_id] = process_slice(
        slice_id   = slice_id,
        nn_dir     = NN_DIR,
        glmpca_dir = GLMPCA_DIR,
        output_dir = OUTPUT_DIR,
        num_layers = num_layers
    )

print(f"\n{'='*60}")
print("All slices processed.")
print(f"{'='*60}")

In [None]:
# =============================================================================
# Compare against ground truth — compute ARI and NMI
# =============================================================================
# ARI (Adjusted Rand Index): agreement between two clusterings, corrected for
# chance. Range [-0.5, 1] (scikit-learn's documented bounds). Random labelings
# give ARI ≈ 0.
#
# NMI (Normalized Mutual Information): shared information between two
# clusterings. Range [0, 1]. Higher is better.
#
# Both metrics are permutation-invariant: they don't require label alignment.

SAMPLE1_DATA_DIR = f'{BASE_DIR}/ST_Data/DLPFC_sample1/original_data'
SAMPLE3_DATA_DIR = f'{BASE_DIR}/ST_Data/DLPFC_sample3/original_data'


def compare_with_ground_truth(slice_id, results, data_dir, output_dir):
    """
    Compute ARI and NMI between GASTON predictions and expert annotations.
    Saves side-by-side plots of ground truth vs. prediction.
    """
    print(f"\nEvaluating {slice_id}...")

    # Load expert annotations
    adata = sc.read_h5ad(f'{data_dir}/{slice_id}.h5ad')
    gt_labels = adata.obs['original_domain'].astype(str).values

    pred_labels = results[slice_id]['labels']
    S           = results[slice_id]['S']

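    # Map string labels to integers for metric computation. Labels outside
    # L1-L6/WM (e.g. missing annotations) map to -1 and are scored as an extra
    # cluster; mask with gt_labels_numeric >= 0 to evaluate annotated spots only.
    gt_label_map = {'L1': 0, 'L2': 1, 'L3': 2, 'L4': 3, 'L5': 4, 'L6': 5, 'WM': 6}
    gt_labels_numeric = np.array([gt_label_map.get(l, -1) for l in gt_labels])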

    ari = adjusted_rand_score(gt_labels_numeric, pred_labels)
    nmi = normalized_mutual_info_score(gt_labels_numeric, pred_labels)

    print(f"  ARI: {ari:.4f}")
    print(f"  NMI: {nmi:.4f}")

    # Plot ground truth
    fig_path = f'{output_dir}/{slice_id}/ground_truth.png'
    cluster_plotting.plot_clusters(gt_labels_numeric, S, figsize=(6, 6),
                                   color_palette=plt.cm.tab10, s=10, lgd=True)
    plt.title(f'{slice_id} Ground Truth', fontsize=14, fontweight='bold')
    plt.savefig(fig_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close()

    # Plot GASTON prediction
    fig_path = f'{output_dir}/{slice_id}/gaston_prediction.png'
    cluster_plotting.plot_clusters(pred_labels, S, figsize=(6, 6),
                                   color_palette=plt.cm.tab10, s=10, lgd=True)
    plt.title(f'{slice_id} GASTON (ARI={ari:.4f}, NMI={nmi:.4f})', fontsize=14, fontweight='bold')
    plt.savefig(fig_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close()

    return {'ari': ari, 'nmi': nmi}


metrics_results = {}

for slice_id in slices_to_process:
    data_dir = SAMPLE1_DATA_DIR if slice_id.startswith('1515') else SAMPLE3_DATA_DIR
    metrics_results[slice_id] = compare_with_ground_truth(slice_id, results, data_dir, OUTPUT_DIR)

# Summary table
print("\n" + "="*60)
print("ARI / NMI Summary (best-by-loss model selection)")
print("="*60)
print(f"{'Slice':<12} {'ARI':<20} {'NMI':<20}")
print("-"*52)
for slice_id, metrics in metrics_results.items():
    print(f"{slice_id:<12} {metrics['ari']:<20.4f} {metrics['nmi']:<20.4f}")
print("-"*52)
avg_ari = np.mean([m['ari'] for m in metrics_results.values()])
avg_nmi = np.mean([m['nmi'] for m in metrics_results.values()])
print(f"{'Mean':<12} {avg_ari:<20.4f} {avg_nmi:<20.4f}")

In [None]:
# =============================================================================
# Oracle model selection: iterate all 30 reps, pick by best ARI
# =============================================================================
# In the unsupervised setting, we would pick by loss.  Because we have ground
# truth labels, we can also do 'oracle' selection: loop over all 30 reps,
# extract domains for each, and pick the one with the highest ARI.  This gives
# an upper bound on how well any single restart can do.

import torch
import shutil


def test_all_reps_for_slice(slice_id, nn_dir, data_dir, num_layers=7):
    """
    Evaluate all 30 random restarts for a slice and return ARI/NMI per rep.

    Parameters
    ----------
    slice_id   : str  e.g. '151507'
    nn_dir     : str  directory containing rep0/ .. rep29/
    data_dir   : str  directory containing {slice_id}.h5ad with ground truth
    num_layers : int  number of domains for DP segmentation

    Returns
    -------
    dict: rep_name -> {'ari', 'nmi', 'loss', 'isodepth_range', 'isodepth_std'}
    """
    print(f"\n{'='*60}")
    print(f"Testing all reps for {slice_id}")
    print(f"{'='*60}")

    # Load ground truth
    adata = sc.read_h5ad(f'{data_dir}/{slice_id}.h5ad')
    gt_labels = adata.obs['original_domain'].astype(str).values
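    # Labels outside L1-L6/WM (e.g. missing annotations) map to -1 and count as
    # an extra cluster in ARI/NMI, matching compare_with_ground_truth above.
    gt_label_map = {'L1': 0, 'L2': 1, 'L3': 2, 'L4': 3, 'L5': 4, 'L6': 5, 'WM': 6}
    gt_labels_numeric = np.array([gt_label_map.get(l, -1) for l in gt_labels])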

    rep_results = {}
    slice_nn_dir = f'{nn_dir}/{slice_id}'

    for rep_idx in range(30):
        rep_name = f'rep{rep_idx}'
        rep_dir  = f'{slice_nn_dir}/{rep_name}'

        if not os.path.exists(f'{rep_dir}/final_model.pt'):
            continue

        try:
            A_torch      = torch.load(f'{rep_dir}/Atorch.pt', weights_only=False)
            S_torch      = torch.load(f'{rep_dir}/Storch.pt', weights_only=False)
            gaston_model = torch.load(f'{rep_dir}/final_model.pt', weights_only=False)

            A_torch      = A_torch.cpu()
            S_torch      = S_torch.cpu()
            gaston_model = gaston_model.cpu()

            A = A_torch.detach().numpy()
            S = S_torch.detach().numpy()

            with open(f'{rep_dir}/min_loss.txt') as f:
                min_loss = float(f.read().strip())

            gaston_isodepth, gaston_labels = dp_related.get_isodepth_labels(
                gaston_model, A, S, num_layers)

            ari = adjusted_rand_score(gt_labels_numeric, gaston_labels)
            nmi = normalized_mutual_info_score(gt_labels_numeric, gaston_labels)

            rep_results[rep_name] = {
                'ari'            : ari,
                'nmi'            : nmi,
                'loss'           : min_loss,
                'isodepth_range' : (gaston_isodepth.min(), gaston_isodepth.max()),
                'isodepth_std'   : gaston_isodepth.std()
            }

            print(f"  {rep_name}: ARI={ari:.4f}, NMI={nmi:.4f}, loss={min_loss:.6f}")

        except Exception as e:
            print(f"  {rep_name}: error — {e}")

    return rep_results


# --- 151674 ---
results_151674 = test_all_reps_for_slice('151674', NN_DIR, SAMPLE3_DATA_DIR)

best_rep_151674 = max(results_151674.items(), key=lambda x: x[1]['ari'])
print(f"\n{'='*60}")
print(f"151674 best rep : {best_rep_151674[0]}")
print(f"  ARI  : {best_rep_151674[1]['ari']:.4f}")
print(f"  NMI  : {best_rep_151674[1]['nmi']:.4f}")
print(f"  Loss : {best_rep_151674[1]['loss']:.6f}")
print(f"{'='*60}")

In [None]:
# =============================================================================
# Save best-by-ARI result for 151674
# =============================================================================

best_rep = best_rep_151674[0]
slice_id = '151674'
rep_dir  = f'{NN_DIR}/{slice_id}/{best_rep}'
dst_dir  = f'{OUTPUT_DIR}/{slice_id}.2'

if os.path.exists(dst_dir):
    shutil.rmtree(dst_dir)
os.makedirs(dst_dir, exist_ok=True)

A_torch      = torch.load(f'{rep_dir}/Atorch.pt', weights_only=False).cpu()
S_torch      = torch.load(f'{rep_dir}/Storch.pt', weights_only=False).cpu()
gaston_model = torch.load(f'{rep_dir}/final_model.pt', weights_only=False).cpu()

A = A_torch.detach().numpy()
S = S_torch.detach().numpy()

gaston_isodepth, gaston_labels = dp_related.get_isodepth_labels(gaston_model, A, S, 7)

adata = sc.read_h5ad(f'{SAMPLE3_DATA_DIR}/{slice_id}.h5ad')
gt_labels = adata.obs['original_domain'].astype(str).values
gt_label_map = {'L1': 0, 'L2': 1, 'L3': 2, 'L4': 3, 'L5': 4, 'L6': 5, 'WM': 6}
gt_labels_numeric = np.array([gt_label_map.get(l, -1) for l in gt_labels])

ari = adjusted_rand_score(gt_labels_numeric, gaston_labels)
nmi = normalized_mutual_info_score(gt_labels_numeric, gaston_labels)

cluster_plotting.plot_isodepth(gaston_isodepth, S, gaston_model,
                               figsize=(7, 6), streamlines=True, cmap='Reds')
plt.title(f'{slice_id} Isodepth ({best_rep})', fontsize=14, fontweight='bold')
plt.savefig(f'{dst_dir}/isodepth_map.png', dpi=150, bbox_inches='tight')
plt.show(); plt.close()

cluster_plotting.plot_clusters(gt_labels_numeric, S, figsize=(6, 6),
                               color_palette=plt.cm.tab10, s=10, lgd=True)
plt.title(f'{slice_id} Ground Truth', fontsize=14, fontweight='bold')
plt.savefig(f'{dst_dir}/ground_truth.png', dpi=150, bbox_inches='tight')
plt.show(); plt.close()

cluster_plotting.plot_clusters(gaston_labels, S, figsize=(6, 6),
                               color_palette=plt.cm.tab10, s=10, lgd=True)
plt.title(f'{slice_id} GASTON ({best_rep}, ARI={ari:.4f}, NMI={nmi:.4f})', fontsize=14, fontweight='bold')
plt.savefig(f'{dst_dir}/gaston_prediction.png', dpi=150, bbox_inches='tight')
plt.show(); plt.close()

cluster_plotting.plot_clusters(gaston_labels, S, figsize=(6, 6),
                               color_palette=plt.cm.tab10, s=10, lgd=True)
plt.title(f'{slice_id} Spatial Domains ({best_rep})', fontsize=14, fontweight='bold')
plt.savefig(f'{dst_dir}/spatial_domains.png', dpi=150, bbox_inches='tight')
plt.show(); plt.close()

np.save(f'{dst_dir}/gaston_isodepth.npy', gaston_isodepth)
np.save(f'{dst_dir}/gaston_labels.npy', gaston_labels)

print(f"\n{'='*60}")
print(f"{slice_id} best result ({best_rep}) saved to: {dst_dir}")
print(f"  ARI: {ari:.4f}")
print(f"  NMI: {nmi:.4f}")
print(f"{'='*60}")

In [None]:
# =============================================================================
# Oracle selection: 151507
# =============================================================================

results_151507 = test_all_reps_for_slice('151507', NN_DIR, SAMPLE1_DATA_DIR)

best_rep_151507 = max(results_151507.items(), key=lambda x: x[1]['ari'])
print(f"\n{'='*60}")
print(f"151507 best rep : {best_rep_151507[0]}")
print(f"  ARI  : {best_rep_151507[1]['ari']:.4f}")
print(f"  NMI  : {best_rep_151507[1]['nmi']:.4f}")
print(f"  Loss : {best_rep_151507[1]['loss']:.6f}")
print(f"{'='*60}")

In [None]:
# =============================================================================
# Save best-by-ARI result for 151507
# =============================================================================

best_rep = best_rep_151507[0]
slice_id = '151507'
rep_dir  = f'{NN_DIR}/{slice_id}/{best_rep}'
dst_dir  = f'{OUTPUT_DIR}/{slice_id}.2'

if os.path.exists(dst_dir):
    shutil.rmtree(dst_dir)
os.makedirs(dst_dir, exist_ok=True)

A_torch      = torch.load(f'{rep_dir}/Atorch.pt', weights_only=False).cpu()
S_torch      = torch.load(f'{rep_dir}/Storch.pt', weights_only=False).cpu()
gaston_model = torch.load(f'{rep_dir}/final_model.pt', weights_only=False).cpu()

A = A_torch.detach().numpy()
S = S_torch.detach().numpy()

gaston_isodepth, gaston_labels = dp_related.get_isodepth_labels(gaston_model, A, S, 7)

adata = sc.read_h5ad(f'{SAMPLE1_DATA_DIR}/{slice_id}.h5ad')
gt_labels = adata.obs['original_domain'].astype(str).values
gt_label_map = {'L1': 0, 'L2': 1, 'L3': 2, 'L4': 3, 'L5': 4, 'L6': 5, 'WM': 6}
gt_labels_numeric = np.array([gt_label_map.get(l, -1) for l in gt_labels])

ari = adjusted_rand_score(gt_labels_numeric, gaston_labels)
nmi = normalized_mutual_info_score(gt_labels_numeric, gaston_labels)

cluster_plotting.plot_isodepth(gaston_isodepth, S, gaston_model,
                               figsize=(7, 6), streamlines=True, cmap='Reds')
plt.title(f'{slice_id} Isodepth ({best_rep})', fontsize=14, fontweight='bold')
plt.savefig(f'{dst_dir}/isodepth_map.png', dpi=150, bbox_inches='tight')
plt.show(); plt.close()

cluster_plotting.plot_clusters(gt_labels_numeric, S, figsize=(6, 6),
                               color_palette=plt.cm.tab10, s=10, lgd=True)
plt.title(f'{slice_id} Ground Truth', fontsize=14, fontweight='bold')
plt.savefig(f'{dst_dir}/ground_truth.png', dpi=150, bbox_inches='tight')
plt.show(); plt.close()

cluster_plotting.plot_clusters(gaston_labels, S, figsize=(6, 6),
                               color_palette=plt.cm.tab10, s=10, lgd=True)
plt.title(f'{slice_id} GASTON ({best_rep}, ARI={ari:.4f}, NMI={nmi:.4f})', fontsize=14, fontweight='bold')
plt.savefig(f'{dst_dir}/gaston_prediction.png', dpi=150, bbox_inches='tight')
plt.show(); plt.close()

cluster_plotting.plot_clusters(gaston_labels, S, figsize=(6, 6),
                               color_palette=plt.cm.tab10, s=10, lgd=True)
plt.title(f'{slice_id} Spatial Domains ({best_rep})', fontsize=14, fontweight='bold')
plt.savefig(f'{dst_dir}/spatial_domains.png', dpi=150, bbox_inches='tight')
plt.show(); plt.close()

np.save(f'{dst_dir}/gaston_isodepth.npy', gaston_isodepth)
np.save(f'{dst_dir}/gaston_labels.npy', gaston_labels)

print(f"\n{'='*60}")
print(f"{slice_id} best result ({best_rep}) saved to: {dst_dir}")
print(f"  ARI: {ari:.4f}")
print(f"  NMI: {nmi:.4f}")
print(f"{'='*60}")

In [None]:
# =============================================================================
# Oracle selection: 151508
# =============================================================================

results_151508 = test_all_reps_for_slice('151508', NN_DIR, SAMPLE1_DATA_DIR)

best_rep_151508 = max(results_151508.items(), key=lambda x: x[1]['ari'])
print(f"\n{'='*60}")
print(f"151508 best rep : {best_rep_151508[0]}")
print(f"  ARI  : {best_rep_151508[1]['ari']:.4f}")
print(f"  NMI  : {best_rep_151508[1]['nmi']:.4f}")
print(f"  Loss : {best_rep_151508[1]['loss']:.6f}")
print(f"{'='*60}")

In [None]:
# =============================================================================
# Save best-by-ARI result for 151508
# =============================================================================

best_rep = best_rep_151508[0]
slice_id = '151508'
rep_dir  = f'{NN_DIR}/{slice_id}/{best_rep}'
dst_dir  = f'{OUTPUT_DIR}/{slice_id}.2'

if os.path.exists(dst_dir):
    shutil.rmtree(dst_dir)
os.makedirs(dst_dir, exist_ok=True)

A_torch      = torch.load(f'{rep_dir}/Atorch.pt', weights_only=False).cpu()
S_torch      = torch.load(f'{rep_dir}/Storch.pt', weights_only=False).cpu()
gaston_model = torch.load(f'{rep_dir}/final_model.pt', weights_only=False).cpu()

A = A_torch.detach().numpy()
S = S_torch.detach().numpy()

gaston_isodepth, gaston_labels = dp_related.get_isodepth_labels(gaston_model, A, S, 7)

adata = sc.read_h5ad(f'{SAMPLE1_DATA_DIR}/{slice_id}.h5ad')
gt_labels = adata.obs['original_domain'].astype(str).values
gt_label_map = {'L1': 0, 'L2': 1, 'L3': 2, 'L4': 3, 'L5': 4, 'L6': 5, 'WM': 6}
gt_labels_numeric = np.array([gt_label_map.get(l, -1) for l in gt_labels])

ari = adjusted_rand_score(gt_labels_numeric, gaston_labels)
nmi = normalized_mutual_info_score(gt_labels_numeric, gaston_labels)

cluster_plotting.plot_isodepth(gaston_isodepth, S, gaston_model,
                               figsize=(7, 6), streamlines=True, cmap='Reds')
plt.title(f'{slice_id} Isodepth ({best_rep})', fontsize=14, fontweight='bold')
plt.savefig(f'{dst_dir}/isodepth_map.png', dpi=150, bbox_inches='tight')
plt.show(); plt.close()

cluster_plotting.plot_clusters(gt_labels_numeric, S, figsize=(6, 6),
                               color_palette=plt.cm.tab10, s=10, lgd=True)
plt.title(f'{slice_id} Ground Truth', fontsize=14, fontweight='bold')
plt.savefig(f'{dst_dir}/ground_truth.png', dpi=150, bbox_inches='tight')
plt.show(); plt.close()

cluster_plotting.plot_clusters(gaston_labels, S, figsize=(6, 6),
                               color_palette=plt.cm.tab10, s=10, lgd=True)
plt.title(f'{slice_id} GASTON ({best_rep}, ARI={ari:.4f}, NMI={nmi:.4f})', fontsize=14, fontweight='bold')
plt.savefig(f'{dst_dir}/gaston_prediction.png', dpi=150, bbox_inches='tight')
plt.show(); plt.close()

cluster_plotting.plot_clusters(gaston_labels, S, figsize=(6, 6),
                               color_palette=plt.cm.tab10, s=10, lgd=True)
plt.title(f'{slice_id} Spatial Domains ({best_rep})', fontsize=14, fontweight='bold')
plt.savefig(f'{dst_dir}/spatial_domains.png', dpi=150, bbox_inches='tight')
plt.show(); plt.close()

np.save(f'{dst_dir}/gaston_isodepth.npy', gaston_isodepth)
np.save(f'{dst_dir}/gaston_labels.npy', gaston_labels)

print(f"\n{'='*60}")
print(f"{slice_id} best result ({best_rep}) saved to: {dst_dir}")
print(f"  ARI: {ari:.4f}")
print(f"  NMI: {nmi:.4f}")
print(f"{'='*60}")

In [None]:
# =============================================================================
# Oracle selection: 151673
# =============================================================================

results_151673 = test_all_reps_for_slice('151673', NN_DIR, SAMPLE3_DATA_DIR)

best_rep_151673 = max(results_151673.items(), key=lambda x: x[1]['ari'])
print(f"\n{'='*60}")
print(f"151673 best rep : {best_rep_151673[0]}")
print(f"  ARI  : {best_rep_151673[1]['ari']:.4f}")
print(f"  NMI  : {best_rep_151673[1]['nmi']:.4f}")
print(f"  Loss : {best_rep_151673[1]['loss']:.6f}")
print(f"{'='*60}")

In [None]:
# =============================================================================
# Save best-by-ARI result for 151673
# =============================================================================

best_rep = best_rep_151673[0]
slice_id = '151673'
rep_dir  = f'{NN_DIR}/{slice_id}/{best_rep}'
dst_dir  = f'{OUTPUT_DIR}/{slice_id}.2'   # '.2' suffix = best-by-ARI (oracle) result; best-by-loss lives in {slice_id}/

if os.path.exists(dst_dir):
    shutil.rmtree(dst_dir)
os.makedirs(dst_dir, exist_ok=True)

A_torch      = torch.load(f'{rep_dir}/Atorch.pt', weights_only=False).cpu()
S_torch      = torch.load(f'{rep_dir}/Storch.pt', weights_only=False).cpu()
gaston_model = torch.load(f'{rep_dir}/final_model.pt', weights_only=False).cpu()

A = A_torch.detach().numpy()
S = S_torch.detach().numpy()

gaston_isodepth, gaston_labels = dp_related.get_isodepth_labels(gaston_model, A, S, 7)

adata = sc.read_h5ad(f'{SAMPLE3_DATA_DIR}/{slice_id}.h5ad')
gt_labels = adata.obs['original_domain'].astype(str).values
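# Map layer names to integers. Spots without an annotation (e.g. 'nan') become -1
# and are scored by ARI/NMI below as one extra class; keep this consistent with
# the scoring used during oracle selection.
gt_label_map = {'L1': 0, 'L2': 1, 'L3': 2, 'L4': 3, 'L5': 4, 'L6': 5, 'WM': 6}
gt_labels_numeric = np.array([gt_label_map.get(l, -1) for l in gt_labels])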

ari = adjusted_rand_score(gt_labels_numeric, gaston_labels)
nmi = normalized_mutual_info_score(gt_labels_numeric, gaston_labels)

cluster_plotting.plot_isodepth(gaston_isodepth, S, gaston_model,
                               figsize=(7, 6), streamlines=True, cmap='Reds')
plt.title(f'{slice_id} Isodepth ({best_rep})', fontsize=14, fontweight='bold')
plt.savefig(f'{dst_dir}/isodepth_map.png', dpi=150, bbox_inches='tight')
plt.show(); plt.close()

cluster_plotting.plot_clusters(gt_labels_numeric, S, figsize=(6, 6),
                               color_palette=plt.cm.tab10, s=10, lgd=True)
plt.title(f'{slice_id} Ground Truth', fontsize=14, fontweight='bold')
plt.savefig(f'{dst_dir}/ground_truth.png', dpi=150, bbox_inches='tight')
plt.show(); plt.close()

cluster_plotting.plot_clusters(gaston_labels, S, figsize=(6, 6),
                               color_palette=plt.cm.tab10, s=10, lgd=True)
plt.title(f'{slice_id} GASTON ({best_rep}, ARI={ari:.4f}, NMI={nmi:.4f})', fontsize=14, fontweight='bold')
plt.savefig(f'{dst_dir}/gaston_prediction.png', dpi=150, bbox_inches='tight')
plt.show(); plt.close()

cluster_plotting.plot_clusters(gaston_labels, S, figsize=(6, 6),
                               color_palette=plt.cm.tab10, s=10, lgd=True)
plt.title(f'{slice_id} Spatial Domains ({best_rep})', fontsize=14, fontweight='bold')
plt.savefig(f'{dst_dir}/spatial_domains.png', dpi=150, bbox_inches='tight')
plt.show(); plt.close()

np.save(f'{dst_dir}/gaston_isodepth.npy', gaston_isodepth)
np.save(f'{dst_dir}/gaston_labels.npy', gaston_labels)

print(f"\n{'='*60}")
print(f"{slice_id} best result ({best_rep}) saved to: {dst_dir}")
print(f"  ARI: {ari:.4f}")
print(f"  NMI: {nmi:.4f}")
print(f"{'='*60}")

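Two details of the scoring above are worth making explicit. First, GASTON's label IDs need not follow the same order as `gt_label_map` (the domains could be numbered from WM towards L1), but ARI and NMI compare partitions only, so any relabelling gives the same score. Second, spots mapped to `-1` by `gt_label_map.get(l, -1)` are unannotated; a common alternative is to drop them before scoring. The helper `score_annotated` below is a hypothetical sketch of that alternative on toy labels, assuming `-1` marks only missing annotations.

```python
import numpy as np
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score


def score_annotated(gt_numeric, pred, missing=-1):
    """ARI and NMI computed only on spots that have a ground-truth label."""
    gt_numeric = np.asarray(gt_numeric)
    pred = np.asarray(pred)
    keep = gt_numeric != missing  # drop unannotated spots
    ari = adjusted_rand_score(gt_numeric[keep], pred[keep])
    nmi = normalized_mutual_info_score(gt_numeric[keep], pred[keep])
    return float(ari), float(nmi)


# Toy labels: the last spot has no annotation
gt = np.array([0, 0, 1, 1, 2, 2, -1])
pred_same = np.array([0, 0, 1, 1, 2, 2, 0])
pred_perm = np.array([2, 2, 1, 1, 0, 0, 0])  # same partition, reversed IDs

print(score_annotated(gt, pred_same))
print(score_annotated(gt, pred_perm))        # identical: label IDs do not matter
print(adjusted_rand_score(gt, pred_perm))    # unmasked: the -1 spot counts as a class
```

Whichever convention is chosen, it should be applied identically in `test_all_reps_for_slice` and in the per-slice save cells so the ARI in the summary table matches the ARI in the saved figure titles.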
In [None]:
# =============================================================================
# Summary table — oracle (best-by-ARI) results across all slices
# =============================================================================

print("\n" + "="*70)
print("Oracle selection summary (best-by-ARI across 30 restarts)")
print("="*70)
print(f"{'Slice':<12} {'Best rep':<10} {'ARI':<12} {'NMI':<12} {'Loss':<12}")
print("-"*70)

# Each best_rep_* is a (rep_name, metrics_dict) tuple from the oracle cells above
oracle_results = {
    '151507': best_rep_151507,
    '151508': best_rep_151508,
    '151673': best_rep_151673,
    '151674': best_rep_151674,
}
for sid, (rep, m) in oracle_results.items():
    print(f"{sid:<12} {rep:<10} {m['ari']:<12.4f} {m['nmi']:<12.4f} {m['loss']:<12.6f}")

print("-"*70)

avg_ari = np.mean([m['ari'] for _, m in oracle_results.values()])
avg_nmi = np.mean([m['nmi'] for _, m in oracle_results.values()])
print(f"{'Mean':<12} {'':<10} {avg_ari:<12.4f} {avg_nmi:<12.4f}")

print("\n" + "="*70)
print("Best results saved to:")
for sid in oracle_results:
    print(f"  {OUTPUT_DIR}/{sid}.2/")
print("="*70)

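With only four slices, the mean alone hides how much performance varies between sections. A minimal sketch of reporting the sample standard deviation next to the mean, assuming per-slice metrics are collected in a dict keyed by slice ID; the helper `summarize` and the numbers below are hypothetical.

```python
import numpy as np


def summarize(metrics):
    """metrics: slice_id -> {'ari': float, 'nmi': float}.

    Returns {'ari': (mean, sample_std), 'nmi': (mean, sample_std)}.
    ddof=1 gives the sample (n-1) standard deviation, appropriate for few slices.
    """
    out = {}
    for key in ('ari', 'nmi'):
        vals = np.array([m[key] for m in metrics.values()], dtype=float)
        out[key] = (vals.mean(), vals.std(ddof=1))
    return out


# Toy values, not real results
toy_metrics = {
    'slice_a': {'ari': 0.5, 'nmi': 0.6},
    'slice_b': {'ari': 0.6, 'nmi': 0.7},
    'slice_c': {'ari': 0.7, 'nmi': 0.8},
    'slice_d': {'ari': 0.6, 'nmi': 0.7},
}
for key, (mean, std) in summarize(toy_metrics).items():
    print(f"{key.upper()}: {mean:.4f} ± {std:.4f}")
```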
In [None]:
# =============================================================================
# Summary
# =============================================================================
# Each slice now has:
#   results/{slice_id}/                 <- best-by-loss
#     model_selection.png              log-likelihood curve
#     isodepth_map.png                 continuous isodepth field
#     spatial_domains.png              discrete domain labels
#     ground_truth.png                 expert annotation
#     gaston_prediction.png            GASTON vs. ground truth
#     gaston_isodepth.npy              saved isodepth array
#     gaston_labels.npy                saved domain label array
#   results/{slice_id}.2/              <- best-by-ARI (oracle)
#     (same files, except model_selection.png)
#
# Next step: spatially-variable gene analysis using the isodepth as a
# continuous pseudo-spatial coordinate.

print("=" * 60)
print("Step 4 complete.  Output structure:")
print("=" * 60)
for sid in slices_to_process:
    print(f"  {OUTPUT_DIR}/{sid}/")
    print("    model_selection.png  isodepth_map.png  spatial_domains.png")
    print("    ground_truth.png     gaston_prediction.png")
    print("    gaston_isodepth.npy  gaston_labels.npy")
print()
print("Next step: spatially-variable gene analysis.")
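The next step treats the isodepth as a continuous pseudo-spatial coordinate. A minimal sketch of the idea: within each domain, regress one gene's expression on isodepth so that each domain gets its own slope and intercept (a piecewise-linear profile). This uses plain least squares via `np.polyfit` as a simplified stand-in; GASTON's own downstream fitting is more involved and is not reproduced here. The helper `piecewise_linear_fit` is hypothetical, and it assumes the isodepth, labels and expression are 1-D arrays aligned by spot (as with the saved `gaston_isodepth.npy` and `gaston_labels.npy`).

```python
import numpy as np


def piecewise_linear_fit(isodepth, labels, expr):
    """Per-domain least-squares fit of expr ~ slope * isodepth + intercept.

    Returns {domain_id: (slope, intercept)}; domains with fewer than 2 spots are skipped.
    """
    isodepth, labels, expr = map(np.asarray, (isodepth, labels, expr))
    fits = {}
    for dom in np.unique(labels):
        mask = labels == dom
        if mask.sum() < 2:
            continue
        slope, intercept = np.polyfit(isodepth[mask], expr[mask], deg=1)
        fits[int(dom)] = (float(slope), float(intercept))
    return fits


# Synthetic example: expression rises with isodepth in domain 0, is flat in domain 1
d = np.linspace(0.0, 2.0, 21)
lab = (d > 1.0).astype(int)
y = np.where(lab == 0, 3.0 * d + 1.0, 4.0)
fits = piecewise_linear_fit(d, lab, y)
print(fits)
```

A gene whose slope changes sharply between adjacent domains, or whose slope is large within a domain, is a candidate spatially-variable gene along the isodepth axis.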