# Task 4 (Optional): Prioritize Targets

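Compute a composite score per perturbation combining:
- Δ distance to healthy (negative = moved closer to healthy, i.e. an improvement)
- Wasserstein distance after perturbation (lower is better)
- kNN healthy-overlap gain (higher is better)
- Silhouette score after perturbation (lower ALS_pert-vs-healthy separation may indicate rescue)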

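Also compute OffTargetVar: the variance, across ALS cells, of the per-cell shift norms ‖ALS_pert − ALS‖. It is subtracted from the score to penalize heterogeneous responses.
Outputs: `data/figs/task4_top_targets.png` and a CSV of rankings (`data/figs/task4_rankings.csv`).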


In [None]:
# Ensure repo root is on sys.path for `utils` imports
from pathlib import Path
import sys
# Make the repository root importable so that `from utils...` works in later cells.
# Presumably the notebook runs from a directory one level below the repo root;
# confirm this if launching Jupyter from elsewhere.
repo_root = Path.cwd().parent
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))
    print('Added to sys.path:', repo_root)
else:
    # Re-running the cell must not claim it changed sys.path.
    print('Already on sys.path:', repo_root)


In [None]:
from pathlib import Path
import numpy as np
import pandas as pd

from utils.metrics import composite_score
from utils.plotting import barplot_scores

import warnings

# NOTE(review): these paths are relative to the current working directory, whereas the
# sys.path setup cell treats cwd's parent as the repo root -- confirm where data/ lives.
EMB_DIR = Path('data/embeddings')  # not read in this cell
FIG_DIR = Path('data/figs')
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Load per-perturbation metrics written by Task 3.
metrics_path = FIG_DIR / 'task3_metrics.csv'
metrics_df = pd.read_csv(metrics_path)

# Weights handed to composite_score (presumably a weighted sum -- confirm in utils.metrics).
# Signs are chosen so that a higher composite score means a more promising target.
weights = {
    'delta_to_healthy': -0.5,   # negative change is good -> negative weight makes improvement positive
    'wasserstein_after': -0.2,  # lower after is better
    'knn_overlap_gain': 0.2,    # higher is better
    'silhouette_after': -0.1,   # lower separation of ALS_pert vs healthy may indicate rescue
}

# Fail early with a readable message rather than a bare KeyError in the middle of scoring.
missing_cols = [col for col in ['perturbation', *weights] if col not in metrics_df.columns]
if missing_cols:
    raise KeyError(f'{metrics_path} is missing required columns: {missing_cols}')

# Scores are keyed by perturbation name, so duplicate rows overwrite each other
# (the last row wins). Keep that behaviour, but surface it instead of hiding it.
dup_names = metrics_df.loc[metrics_df['perturbation'].duplicated(), 'perturbation'].unique()
if len(dup_names):
    warnings.warn(
        f'Duplicate perturbations in {metrics_path}; keeping the last row for: {list(dup_names)}'
    )

# One composite score per perturbation, built from exactly the weighted metric columns.
scores = {
    row['perturbation']: composite_score({key: row[key] for key in weights}, weights)
    for _, row in metrics_df.iterrows()
}

# OffTargetVar: variance across ALS cells of the per-cell shift norms ||ALS_pert - ALS||.
# It is subtracted (scaled by OFFTARGET_WEIGHT) to penalize heterogeneous responses.
# This cell does not compute it: the penalty is applied only when the metrics CSV has
# the column. Otherwise the scores are left unchanged.
# TODO: compute it from the per-cell embeddings when Task 3 does not export it.
OFFTARGET_COL = 'off_target_var'  # NOTE(review): assumed column name -- confirm against Task 3 output
OFFTARGET_WEIGHT = 1.0
if OFFTARGET_COL in metrics_df.columns:
    # dict(zip(...)) mirrors the last-row-wins behaviour of `scores` for duplicate names.
    off_target = dict(zip(metrics_df['perturbation'], metrics_df[OFFTARGET_COL]))
    scores = {name: score - OFFTARGET_WEIGHT * off_target[name] for name, score in scores.items()}

# Plot and save
barplot_scores(scores, 'Task4: Top targets (higher=better)', str(FIG_DIR / 'task4_top_targets.png'))

# Rankings, best (highest score) first.
rank_df = pd.DataFrame({'perturbation': list(scores.keys()), 'score': list(scores.values())}).sort_values('score', ascending=False)
rank_df.to_csv(FIG_DIR / 'task4_rankings.csv', index=False)

rank_df.head()
