# Task 3: Interpret the Embedding Space

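This notebook loads the baseline (healthy, ALS) and perturbation embeddings and computes, for each perturbation:
- the change (Δ) in distance to the healthy centroid
- the 1D Wasserstein distance between ALS and healthy embeddings along a principal component, before and after perturbation
- the gain in kNN overlap with the healthy embeddings (k = 15)
- the healthy-vs-ALS silhouette score before and after perturbation

It also produces a pooled UMAP visualization and a centroid-shift plot.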

Outputs saved to `data/figs/`: `task3_umap.png`, `task3_centroid_shifts.png`, and `task3_metrics.csv`.


In [None]:
# Ensure the repo root is on sys.path so `from utils...` imports resolve.
# The notebook may be launched from the repo root or from a subdirectory
# (e.g. a notebooks/ folder), so search upward for the directory that contains
# `utils/` instead of assuming it is exactly one level above the cwd.
from pathlib import Path
import sys


def find_repo_root(start, marker='utils'):
    """Return the nearest directory at or above *start* that contains *marker*/.

    Falls back to ``start.parent`` (the previous hard-coded assumption) when
    no ancestor contains the marker directory.
    """
    start = Path(start).absolute()
    for candidate in (start, *start.parents):
        if (candidate / marker).is_dir():
            return candidate
    return start.parent


repo_root = find_repo_root(Path.cwd())
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))
    print('Added to sys.path:', repo_root)
else:
    # Re-running the cell should not claim it modified sys.path.
    print('Already on sys.path:', repo_root)


In [None]:
from pathlib import Path
import numpy as np
import pandas as pd

from utils.metrics import (
    centroid, delta_to_healthy, wasserstein1d_along_pc, knn_overlap_fraction,
    silhouette_scores_by_label, composite_score
)
from utils.plotting import umap_2d, plot_umap, plot_centroid_shifts

EMB_DIR = Path('data/embeddings')
FIG_DIR = Path('data/figs')
FIG_DIR.mkdir(parents=True, exist_ok=True)
# NOTE(review): these paths are relative to the current working directory,
# while the sys.path cell locates the repo root separately -- confirm the
# notebook is launched from the directory that contains `data/`.

KNN_K = 15  # neighbourhood size for knn_overlap_fraction


def _load_arr(path):
    """Return the array stored under key 'arr' in the .npz file at *path*."""
    # np.load on an .npz returns an NpzFile that keeps the file open;
    # the context manager closes it once the array has been read.
    with np.load(path) as npz:
        return npz['arr']


# Load embeddings saved in Task 2
healthy_path = EMB_DIR / 'healthy_base.npz'
als_path = EMB_DIR / 'als_base.npz'
if not (healthy_path.exists() and als_path.exists()):
    raise FileNotFoundError("Missing healthy_base.npz or als_base.npz. Run notebook 02 first.")

healthy = _load_arr(healthy_path)
als = _load_arr(als_path)

# Load perturbation embeddings, keyed by the file stem with the cohort prefix
# stripped (e.g. 'healthy_X_up' -> 'X_up', 'als_X_down' -> 'X_down').
# NOTE(review): both groups are scored below as "perturbed ALS" embeddings
# against the healthy baseline -- confirm Task 2 wrote healthy_*_up files
# from perturbed ALS cells.
pert_embeddings = {}
for p in sorted(EMB_DIR.glob('healthy_*_up.npz')):
    pert_embeddings[p.stem[len('healthy_'):]] = _load_arr(p)
for p in sorted(EMB_DIR.glob('als_*_down.npz')):
    pert_embeddings[p.stem[len('als_'):]] = _load_arr(p)
if not pert_embeddings:
    print(f'Warning: no perturbation embeddings found in {EMB_DIR}; metrics table will be empty.')

# Metrics
# The "before" metrics compare unperturbed ALS with healthy and do not depend
# on the perturbation, so compute them once rather than once per perturbation.
w1d = wasserstein1d_along_pc(als, healthy)
knn_base = knn_overlap_fraction(healthy, als, k=KNN_K)
sil_base = silhouette_scores_by_label(
    np.vstack([healthy, als]),
    ['healthy'] * len(healthy) + ['als'] * len(als),
)

rows = []
for name, emb_als_pert in pert_embeddings.items():
    d_health = delta_to_healthy(als, emb_als_pert, healthy)
    w1d_after = wasserstein1d_along_pc(emb_als_pert, healthy)
    knn_gain = knn_overlap_fraction(healthy, emb_als_pert, k=KNN_K) - knn_base
    sil_after = silhouette_scores_by_label(
        np.vstack([healthy, emb_als_pert]),
        ['healthy'] * len(healthy) + ['als_pert'] * len(emb_als_pert),
    )
    rows.append({
        'perturbation': name,
        'delta_to_healthy': d_health,
        'wasserstein_before': w1d,
        'wasserstein_after': w1d_after,
        'knn_overlap_gain': knn_gain,
        'silhouette_before': sil_base,
        'silhouette_after': sil_after,
    })

metrics_df = pd.DataFrame(rows)
metrics_df.to_csv(FIG_DIR / 'task3_metrics.csv', index=False)

# Visualizations
# UMAP of pooled sets: one label per row, in the same order as the stacked arrays.
pool = [healthy, als] + list(pert_embeddings.values())
labels = (
    ['healthy'] * len(healthy)
    + ['als'] * len(als)
    + [pname for pname, emb in pert_embeddings.items() for _ in range(len(emb))]
)
X = np.vstack(pool)
pts = umap_2d(X)
plot_umap(pts, labels, 'Task3: UMAP pooled', str(FIG_DIR / 'task3_umap.png'))

# Centroid shifts plot
plot_centroid_shifts(healthy, als, pert_embeddings, 'Task3: Centroid shifts', str(FIG_DIR / 'task3_centroid_shifts.png'))

metrics_df.head()
