# Step 3: Training the GASTON Neural Network

## What is this step doing?

After Step 2, each spot $i$ has a **14-dimensional embedding** $\mathbf{a}_i \in \mathbb{R}^{14}$ (from GLM-PCA) and a **2D spatial coordinate** $(x_i, y_i)$.  
The question GASTON asks is:

> *Can we explain most of the variation in $\mathbf{a}_i$ using just a **single number** derived from the spot's spatial position?*

That single number is called the **isodepth** $d_i$ — a learned 1D coordinate along the tissue's main axis of expression change (think of it as a "depth into the cortex" axis, but learned entirely from data).

---

## The model: an extreme-bottleneck autoencoder

GASTON is a **bottleneck autoencoder** with bottleneck width = **1**:

```
  (x_i, y_i)                        a_i  ∈ R^14
      │                               ▲
      ▼                               │
  Encoder φ_θ                    Decoder h_ψ
  MLP [2 → 20 → 20 → 1]         MLP [1 → 20 → 20 → 14]
      │                               │
      └──────► d_i  ∈ R ─────────────┘
                (isodepth)
```

**Encoder** $\phi_\theta : \mathbb{R}^2 \to \mathbb{R}$  
Maps each spot's (x, y) coordinates to a scalar isodepth. Architecture: 2 hidden layers of 20 ReLU units.

**Decoder** $h_\psi : \mathbb{R} \to \mathbb{R}^{14}$  
Reconstructs the GLM-PCA embedding from the scalar isodepth. Same hidden architecture as the encoder: 2 hidden layers of 20 ReLU units.

**Loss** (MSE reconstruction):
$$\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \| \mathbf{a}_i - h_\psi(\phi_\theta(x_i, y_i)) \|_2^2$$

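This has the same form as the standard **autoencoder reconstruction loss** from EECS 545, with one difference: the encoder sees only the coordinates $(x_i, y_i)$, never $\mathbf{a}_i$ itself.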
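
As a rough illustration of the diagram and loss above, the model could be written in plain PyTorch as follows. This is a minimal sketch, not GASTON's own implementation: the class name `IsodepthAutoencoder` and the helper `mlp` are made up for this example, and the inputs are random placeholders rather than real spots.

```python
import torch
import torch.nn as nn

def mlp(sizes):
    """Build an MLP with ReLU between layers (no activation on the output)."""
    layers = []
    for i in range(len(sizes) - 1):
        layers.append(nn.Linear(sizes[i], sizes[i + 1]))
        if i < len(sizes) - 2:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)

class IsodepthAutoencoder(nn.Module):
    """Hypothetical stand-in for the model: (x, y) -> d -> a_hat."""
    def __init__(self, embed_dim=14, hidden=(20, 20)):
        super().__init__()
        self.encoder = mlp([2, *hidden, 1])          # phi_theta
        self.decoder = mlp([1, *hidden, embed_dim])  # h_psi

    def forward(self, S):
        d = self.encoder(S)        # (N, 1) isodepth
        return self.decoder(d), d  # (N, 14) reconstruction and (N, 1) isodepth

torch.manual_seed(0)
S = torch.randn(100, 2)    # placeholder coordinates
A = torch.randn(100, 14)   # placeholder embeddings
model = IsodepthAutoencoder()
A_hat, d = model(S)

# Squared L2 error per spot, averaged over spots (the loss written above).
loss = ((A - A_hat) ** 2).sum(dim=1).mean()
```

Note that `nn.MSELoss()` with its default reduction averages over all $N \times 14$ entries, so it equals this loss divided by 14. The two have the same minimizer.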

---

## Why width-1 bottleneck? What does it force the network to learn?

A width-1 bottleneck forces the encoder to **rank all spots on a single axis** such that the decoder can reconstruct the 14-dim embedding from that rank alone.  
Spots that share the same tissue layer will end up with similar $d_i$ values — not because we told the network about layers, but because their gene expression profiles $\mathbf{a}_i$ are similar, and the only way the decoder can reconstruct them well is to assign them nearby isodepths.

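Geometrically: the image of the decoder $h_\psi$ is a **piecewise-linear curve** in $\mathbb{R}^{14}$ (a ReLU network with a 1D input is piecewise linear), and the encoder assigns each spot a position along that curve. Unlike a true projection, the encoder never sees $\mathbf{a}_i$; it must choose the position from $(x_i, y_i)$ alone. The result is analogous to **1D nonlinear PCA** (a.k.a. a principal curve), with the extra constraint that the curve parameter is a function of spatial location.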
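
The piecewise-linear claim can be checked numerically. The sketch below uses a decoder-shaped network with random, untrained weights (an assumption for illustration; no trained GASTON model is involved) and looks at second finite differences along a grid of isodepth values.

```python
import numpy as np

rng = np.random.default_rng(0)

# Random, untrained weights for a decoder-shaped network 1 -> 20 -> 20 -> 14.
sizes = [1, 20, 20, 14]
weights = [rng.normal(size=(m, n)) for m, n in zip(sizes[:-1], sizes[1:])]
biases = [rng.normal(size=n) for n in sizes[1:]]

def decoder(d):
    """Forward pass: ReLU on hidden layers, linear output. d has shape (N, 1)."""
    h = d
    for W, b in zip(weights[:-1], biases[:-1]):
        h = np.maximum(h @ W + b, 0.0)
    return h @ weights[-1] + biases[-1]

# Evaluate the curve on a fine, evenly spaced grid of isodepth values.
grid = np.linspace(-3.0, 3.0, 20001).reshape(-1, 1)
curve = decoder(grid)                      # shape (20001, 14)

# On a straight segment the second finite difference vanishes (up to
# round-off); it is nonzero only next to a kink, where a ReLU unit
# switches on or off.
second_diff = np.abs(curve[2:] - 2 * curve[1:-1] + curve[:-2]).max(axis=1)
near_kink = second_diff > 1e-9
frac_near_kink = near_kink.mean()
print(f"grid points next to a kink: {near_kink.sum()} of {near_kink.size}")
```

Only a small fraction of grid points sit next to a kink; everywhere else the curve is exactly linear in $d$.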

---

## Design choices — other things worth trying

This architecture is **not the only option**. Here are natural variants to consider:

| What to change | Original | Alternative | Effect |
|---|---|---|---|
| **Bottleneck width** | 1 | 2 | 2D isodepth — captures more structure, harder to interpret |
| **Network depth/width** | `[20, 20]` | `[64, 64, 64]` | More capacity, risk of overfitting on ~4k points |
| **Loss function** | MSE on $\mathbf{A}$ | Cosine similarity on $\mathbf{A}$, or Poisson NLL on raw counts | Different inductive bias; Poisson NLL needs non-negative count targets, which $\mathbf{A}$ is not |
| **Spatial regularization** | None | Add $\lambda \sum_{(i,j) \text{ neighbors}} (d_i - d_j)^2$ | Encourages spatially smooth isodepth (like a graph Laplacian regularizer) |
| **Input to encoder** | (x, y) only | (x, y, H&E patch features) | **This is exactly what C-GASTON adds** |
| **Random restarts** | 30 | More/fewer | More restarts = better chance that at least one run reaches a good minimum, at proportionally higher compute cost |

The **Input to encoder** row is the key motivation for the C-GASTON project: the encoder currently ignores the H&E image, which contains rich morphological information about tissue structure.
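
To make the spatial-regularization row concrete, here is one way the penalty $\lambda \sum_{(i,j)} (d_i - d_j)^2$ might be computed. The table does not say how neighbors are defined, so this sketch assumes k-nearest neighbors in $(x, y)$; the function names `knn_pairs` and `smoothness_penalty` are hypothetical.

```python
import numpy as np

def knn_pairs(coords, k):
    """Undirected (i, j) pairs linking each spot to its k nearest spots.

    Brute-force O(N^2) distances: fine for a sketch, but a KD-tree
    would be a better fit for large N.
    """
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)          # a spot is not its own neighbor
    nbrs = np.argsort(dist, axis=1)[:, :k]
    pairs = {
        (min(i, int(j)), max(i, int(j)))
        for i in range(len(coords))
        for j in nbrs[i]
    }
    return np.array(sorted(pairs))

def smoothness_penalty(d, pairs, lam=1.0):
    """lam * sum over neighbor pairs of (d_i - d_j)^2."""
    i, j = pairs[:, 0], pairs[:, 1]
    return lam * np.sum((d[i] - d[j]) ** 2)

# Tiny example: three spots on a line, with made-up isodepth values.
coords_demo = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]])
d_demo = np.array([0.0, 1.0, 3.0])
pairs_demo = knn_pairs(coords_demo, k=1)                 # pairs (0, 1) and (1, 2)
penalty_demo = smoothness_penalty(d_demo, pairs_demo)    # (0-1)^2 + (1-3)^2 = 5.0
```

During training, `d` would be the encoder output and the penalty would be added to the reconstruction loss; the same indexing pattern works on PyTorch tensors.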

---

## Why multiple random restarts?

The MSE loss of a bottleneck autoencoder is **non-convex**, so Adam can converge to different local minima depending on initialization. We run **30 independent training runs** with different random seeds and keep the one with the lowest final loss (the comparison itself is done in Step 4). This is the same idea as running k-means with multiple initializations.
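
The effect of restarts is easy to see on a toy problem. The sketch below is unrelated to the GASTON loss itself: it runs plain gradient descent on a made-up 1D function with several local minima, from several fixed starting points that stand in for random seeds.

```python
import numpy as np

def f(w):
    """Toy non-convex loss with several local minima."""
    return np.sin(3 * w) + 0.1 * w ** 2

def grad_f(w):
    """Derivative of f."""
    return 3 * np.cos(3 * w) + 0.2 * w

def gradient_descent(w0, lr=0.01, steps=2000):
    """Plain gradient descent from the starting point w0."""
    w = w0
    for _ in range(steps):
        w = w - lr * grad_f(w)
    return w

starts = np.linspace(-4.0, 4.0, 9)       # stand-ins for different seeds
finals = np.array([gradient_descent(w0) for w0 in starts])
losses = f(finals)

best = finals[np.argmin(losses)]
for w0, w, loss in zip(starts, finals, losses):
    print(f"start {w0:+.1f} -> w = {w:+.3f}, loss = {loss:+.3f}")
print(f"best run: w = {best:+.3f}")
```

Different starts settle in different minima with different losses, and only some of them reach the lowest one. Keeping the best of several runs guards against an unlucky initialization.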

In [None]:
# =============================================================================
# Step 3: Training the GASTON Bottleneck Autoencoder
# =============================================================================
# EECS 545 Project: C-GASTON
# Reference: https://gaston-tutorial.readthedocs.io/

import os
import numpy as np
import matplotlib.pyplot as plt
import torch

from gaston import neural_net

plt.rcParams['figure.dpi'] = 150
plt.rcParams['savefig.dpi'] = 150

print("Libraries imported.")
print(f"PyTorch version : {torch.__version__}")
print(f"CUDA available  : {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU             : {torch.cuda.get_device_name(0)}")

In [None]:
# =============================================================================
# Paths
# =============================================================================

BASE_DIR = '/home/siruilf/A_new_dataset_for_gaston'

# Input: GLM-PCA embeddings from Step 2
GLMPCA_DIR = f'{BASE_DIR}/2.GLM_PC/glmpca_results'

# Output: trained model checkpoints
NN_OUTPUT_DIR = f'{BASE_DIR}/3.Training_Gaston__NN/nn_results'
os.makedirs(NN_OUTPUT_DIR, exist_ok=True)

print(f"GLM-PCA input  : {GLMPCA_DIR}")
print(f"NN output      : {NN_OUTPUT_DIR}")

In [None]:
# =============================================================================
# Slices to train
# =============================================================================

slices_to_train = ['151507', '151508', '151673', '151674']

print("Checking data availability...")
for sid in slices_to_train:
    ok = (
        os.path.exists(f'{GLMPCA_DIR}/{sid}/glmpca.npy') and
        os.path.exists(f'{GLMPCA_DIR}/{sid}/coords_mat.npy')
    )
    print(f"  {sid}: {'ready' if ok else 'MISSING — run Step 2 first'}")

In [None]:
# =============================================================================
# Hyperparameters
# =============================================================================
#
# isodepth_arch / expression_arch
# --------------------------------
# Both encoder and decoder are 2-hidden-layer MLPs with 20 units each and
# ReLU activations.  The small size is intentional: ~4k training points
# with a simple 1D bottleneck do not need a large network.  A much larger
# network risks overfitting and producing a non-smooth isodepth.
#
# epochs
# ------
# Full-batch training for 10,000 optimizer steps (one step per epoch).
# Because N ~4k fits easily in memory, there is no mini-batching: each
# step uses all spots.  (Mini-batch training is an option for larger datasets.)
#
# num_restarts
# ------------
# The loss landscape is non-convex (ReLU network, bottleneck).  We run 30
# independent training runs with different random seeds and keep the model
# with the lowest final reconstruction loss.  This is the same strategy as
# multi-start k-means.
#
# optimizer
# ---------
# Adam with default learning rate (1e-3).  Standard choice for MLPs.

isodepth_arch   = [20, 20]    # encoder: R^2 -> 20 -> 20 -> R^1
expression_arch = [20, 20]    # decoder: R^1 -> 20 -> 20 -> R^14
epochs          = 10000
checkpoint      = 500         # save model state every 500 epochs
optimizer       = 'adam'
num_restarts    = 30
device          = 'cuda' if torch.cuda.is_available() else 'cpu'

print("Hyperparameters:")
print(f"  encoder arch    : R^2 -> {isodepth_arch} -> R^1")
print(f"  decoder arch    : R^1 -> {expression_arch} -> R^14")
print(f"  epochs          : {epochs}")
print(f"  checkpoint every: {checkpoint} epochs")
print(f"  optimizer       : {optimizer}")
print(f"  random restarts : {num_restarts}")
print(f"  device          : {device}")
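
To put "small network" in numbers, the trainable parameters of each MLP can be counted directly from the layer sizes. The helper `count_mlp_params` is written for this note and assumes fully connected layers with biases.

```python
def count_mlp_params(input_dim, hidden, output_dim):
    """Weights plus biases of a fully connected MLP."""
    sizes = [input_dim, *hidden, output_dim]
    return sum(m * n + n for m, n in zip(sizes[:-1], sizes[1:]))

encoder_params = count_mlp_params(2, [20, 20], 1)     # 60 + 420 + 21  = 501
decoder_params = count_mlp_params(1, [20, 20], 14)    # 40 + 420 + 294 = 754
total_params = encoder_params + decoder_params        # 1255

print(f"encoder: {encoder_params}, decoder: {decoder_params}, total: {total_params}")
```

With roughly 4,000 spots and 14 target dimensions per spot, there are about 56,000 target values for 1,255 parameters.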

In [None]:
# =============================================================================
# Training function
# =============================================================================

def train_gaston(slice_id, glmpca_dir, output_dir,
                 isodepth_arch, expression_arch,
                 epochs, checkpoint, optimizer,
                 num_restarts, device):
    """
    Train the GASTON bottleneck autoencoder for one tissue slice.

    Data flow
    ---------
    Load  A (N, 14) and S (N, 2)  from GLM-PCA step
      |-- z-score normalize both  (neural_net.load_rescale_input_data)
      |-- run `num_restarts` independent training runs
      |-- each run: Adam for `epochs` steps, full-batch MSE loss
      |-- save checkpoint every `checkpoint` steps + final model

    The best model is selected in Step 4 (Process_NN_Output) by comparing
    the final loss across all restarts.

    Parameters
    ----------
    slice_id       : str  e.g. '151507'
    glmpca_dir     : str  directory containing glmpca.npy and coords_mat.npy
    output_dir     : str  where to save model checkpoints
    isodepth_arch  : list encoder hidden layer sizes
    expression_arch: list decoder hidden layer sizes
    epochs         : int  training steps per restart
    checkpoint     : int  save frequency in epochs
    optimizer      : str  'adam'
    num_restarts   : int  number of independent random initializations
    device         : str  'cuda' or 'cpu'
    """
    print(f"\n{'='*60}")
    print(f"Slice: {slice_id}")
    print(f"{'='*60}")

    # --- Load GLM-PCA outputs ---
    # A: the 14-dim embeddings — reconstruction TARGET for the decoder
    # S: the (x, y) coordinates  — INPUT to the encoder
    print("[1/3] Loading GLM-PCA data...")
    A = np.load(f'{glmpca_dir}/{slice_id}/glmpca.npy')
    S = np.load(f'{glmpca_dir}/{slice_id}/coords_mat.npy')
    print(f"      A (embeddings) : {A.shape}  (N x K)")
    print(f"      S (coordinates): {S.shape}  (N x 2)")

    # --- z-score normalize ---
    # Normalize each feature to zero mean and unit variance.
    # This is standard preprocessing before training an MLP — without it,
    # features on very different scales cause unstable gradients.
    # (Same as sklearn's StandardScaler, applied to both S and A.)
    print("\n[2/3] z-score normalizing inputs...")
    S_torch, A_torch = neural_net.load_rescale_input_data(S, A)
    print(f"      S_torch : {tuple(S_torch.shape)}  mean≈0, std≈1 per coordinate")
    print(f"      A_torch : {tuple(A_torch.shape)}  mean≈0, std≈1 per dimension")

    # --- Train num_restarts independent models ---
    # Each restart uses a fresh random weight initialization (different seed).
    # The Adam optimizer runs for `epochs` full-batch gradient steps.
    # We save a checkpoint every `checkpoint` epochs so we can inspect
    # convergence or resume training if needed.
    print(f"\n[3/3] Training {num_restarts} restarts x {epochs} epochs...")
    slice_out = f'{output_dir}/{slice_id}'
    os.makedirs(slice_out, exist_ok=True)

    for seed in range(num_restarts):
        rep_dir = f'{slice_out}/rep{seed}'
        os.makedirs(rep_dir, exist_ok=True)

        mod, loss_list = neural_net.train(
            S_torch, A_torch,
            S_hidden_list=isodepth_arch,
            A_hidden_list=expression_arch,
            epochs=epochs,
            checkpoint=checkpoint,
            device=device,
            save_dir=rep_dir,
            optim=optimizer,
            seed=seed,
            save_final=True
        )
        print(f"  rep{seed:02d}  final loss = {loss_list[-1]:.6f}")

    print(f"\nSlice {slice_id} done.  Results in: {slice_out}")
    return slice_out


print("Function defined.")
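
For reference, the z-score step described in the comments amounts to the following. This is a NumPy sketch of the idea, not a copy of `neural_net.load_rescale_input_data`; the helper `zscore` and the random coordinates are placeholders.

```python
import numpy as np

def zscore(X):
    """Standardize each column to mean 0 and (population) std 1."""
    mean = X.mean(axis=0)
    std = X.std(axis=0)          # ddof=0, the same convention as StandardScaler
    return (X - mean) / std, mean, std

rng = np.random.default_rng(0)
S_demo = rng.uniform(0, 5000, size=(200, 2))    # placeholder pixel coordinates
S_scaled, S_mean, S_std = zscore(S_demo)

# Keeping mean and std allows mapping back to the original units later.
S_back = S_scaled * S_std + S_mean
```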

In [None]:
# =============================================================================
# Train all slices
# =============================================================================
# Each slice: 30 restarts x 10,000 epochs = 300,000 gradient steps.
# Typical runtime: ~5-15 min per slice on a modern GPU.
#
# The best model per slice is NOT selected here — that happens in
# Step 4 (Process_NN_Output), where we compare all 30 restarts by
# their final reconstruction loss and optionally by ARI against ground truth.

trained_paths = {}

for slice_id in slices_to_train:
    trained_paths[slice_id] = train_gaston(
        slice_id       = slice_id,
        glmpca_dir     = GLMPCA_DIR,
        output_dir     = NN_OUTPUT_DIR,
        isodepth_arch  = isodepth_arch,
        expression_arch= expression_arch,
        epochs         = epochs,
        checkpoint     = checkpoint,
        optimizer      = optimizer,
        num_restarts   = num_restarts,
        device         = device,
    )

print(f"\n{'='*60}")
print("All slices trained.")
print(f"{'='*60}")
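
Because the loop above takes a while, it can help to check which restarts actually finished before rerunning anything. The helper below is a hypothetical convenience, not part of GASTON; it only assumes the `rep{seed}/final_model.pt` layout used in this notebook and is demonstrated on a throwaway directory.

```python
import os
import tempfile

def missing_restarts(slice_dir, num_restarts, filename="final_model.pt"):
    """Return the restart indices whose output file is absent."""
    return [
        rep for rep in range(num_restarts)
        if not os.path.exists(os.path.join(slice_dir, f"rep{rep}", filename))
    ]

# Demo: pretend rep1 never finished.
with tempfile.TemporaryDirectory() as tmp:
    for rep in (0, 2):
        os.makedirs(os.path.join(tmp, f"rep{rep}"))
        open(os.path.join(tmp, f"rep{rep}", "final_model.pt"), "w").close()
    todo = missing_restarts(tmp, num_restarts=3)

print(todo)
```

On real output, calling `missing_restarts(f'{NN_OUTPUT_DIR}/{slice_id}', num_restarts)` would list the seeds that still need training.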

In [None]:
# =============================================================================
# Summary
# =============================================================================
# Each slice now has 30 independently trained bottleneck autoencoders.
# The directory structure is:
#
#   nn_results/
#   └── {slice_id}/
#         ├── rep0/
#         │    ├── Atorch.pt          <- normalized A (same for all reps)
#         │    ├── Storch.pt          <- normalized S (same for all reps)
#         │    ├── final_model.pt     <- trained model weights
#         │    ├── min_loss.txt       <- final reconstruction loss
#         │    └── model_epoch_*.pt   <- checkpoints every 500 epochs
#         ├── rep1/
#         ├── ...
#         └── rep29/
#
# Next step (Step 4) loads all 30 models, selects the best by loss or ARI,
# then runs the dp_related module to extract the isodepth and domain labels.

print("=" * 60)
print("Training complete.  File structure:")
print("=" * 60)
for sid in slices_to_train:
    print(f"  {NN_OUTPUT_DIR}/{sid}/")
    print(f"    rep0/ .. rep{num_restarts-1}/")
    print(f"      final_model.pt    <- autoencoder weights")
    print(f"      min_loss.txt      <- final MSE loss (used to pick best rep)")

print()
print("Next step: Process_NN_Output — pick best restart, extract isodepth & domains.")
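
As a preview of the selection in Step 4, the restart with the lowest recorded loss could be found as sketched below. The helper `best_restart` is hypothetical, and it assumes each `min_loss.txt` starts with a single number; the actual file format and the Step 4 code may differ. The demo uses made-up loss values in a temporary directory.

```python
import os
import tempfile

def best_restart(slice_dir, num_restarts, loss_file="min_loss.txt"):
    """Return (rep index, loss) of the restart with the smallest recorded loss."""
    losses = {}
    for rep in range(num_restarts):
        path = os.path.join(slice_dir, f"rep{rep}", loss_file)
        if os.path.exists(path):
            with open(path) as fh:
                # Assumption: the first whitespace-separated token is the loss.
                losses[rep] = float(fh.read().split()[0])
    if not losses:
        raise FileNotFoundError(f"no {loss_file} found under {slice_dir}")
    best = min(losses, key=losses.get)
    return best, losses[best]

# Demo on a throwaway directory with made-up loss values.
with tempfile.TemporaryDirectory() as tmp:
    for rep, loss in enumerate([0.52, 0.31, 0.47]):
        os.makedirs(os.path.join(tmp, f"rep{rep}"))
        with open(os.path.join(tmp, f"rep{rep}", "min_loss.txt"), "w") as fh:
            fh.write(f"{loss}\n")
    best_rep, best_loss = best_restart(tmp, num_restarts=3)

print(best_rep, best_loss)
```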