# ALADYN: Centered (nolr) vs Non-Centered (reparam) Pipeline

Reproduces the full pipeline: **Training -> Pooling -> LOO Prediction -> AUC evaluation**

Each step uses `%run` or `subprocess` on the corresponding script.

## Pipeline overview

```
Step 1: Train on 40 batches (10k each) -> batch checkpoints with phi, psi, kappa, gamma
Step 2: Pool params across batches -> single pooled_phi_kappa_gamma_{type}.pt
Step 3: LOO Prediction -> for each of 5 eval batches, pool params from all EXCEPT that batch,
        then optimize lambda (nolr) or delta (reparam) -> pi tensors
Step 4: Evaluate AUC (static 10yr, dynamic 10yr, dynamic 1yr) -> comparison CSV
```

**Key scripts:**
- Training: `run_aladyn_batch_vector_e_censor_nolor.py` (nolr), `run_aladyn_batch_vector_e_censor_nolor_reparam.py` (reparam)
- Pooling: `pool_phi_kappa_gamma_from_batches.py`
- LOO Prediction: `run_loo_predict_both.py`
- AUC Evaluation: `compare_loo_auc.py`

---
## Step 0: Paths & Setup

In [None]:
import os

# Base directories
SCRIPT_DIR   = '/Users/sarahurbut/aladynoulli2/claudefile'
DATA_DIR     = '/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/'
DROPBOX      = '/Users/sarahurbut/Library/CloudStorage/Dropbox/'
COV_PATH     = '/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/baselinagefamh_withpcs.csv'

# Training output (batch checkpoints)
NOLR_TRAIN_DIR    = DROPBOX + 'censor_e_batchrun_vectorized_nolr'
REPARAM_TRAIN_DIR = DROPBOX + 'censor_e_batchrun_vectorized_REPARAM'

# Pooled parameters
POOLED_NOLR    = DATA_DIR + 'pooled_phi_kappa_gamma_nolr.pt'
POOLED_REPARAM = DATA_DIR + 'pooled_phi_kappa_gamma_reparam.pt'
NOLR_MASTER    = DATA_DIR + 'master_for_fitting_pooled_correctedE.pt'
NOLR_GK        = DATA_DIR + 'pooled_kappa_gamma_nolr.pt'

# LOO Prediction output (pi tensors) — used for final AUC comparison
LOO_NOLR_DIR    = DROPBOX + 'enrollment_predictions_fixedphi_fixedgk_nolr_loo/'
LOO_REPARAM_DIR = DROPBOX + 'enrollment_predictions_fixedphi_fixedgk_reparam_loo/'

# Non-LOO prediction output (for reference only — LOO is the proper comparison)
NOLR_PRED_DIR    = DROPBOX + 'enrollment_predictions_fixedphi_fixedgk_nolr_vectorized/'
REPARAM_PRED_DIR = DROPBOX + 'enrollment_predictions_fixedphi_fixedgk_reparam_vectorized/'

print('Paths configured.')
for label, path in [('NOLR train', NOLR_TRAIN_DIR), ('REPARAM train', REPARAM_TRAIN_DIR),
                     ('LOO NOLR pred', LOO_NOLR_DIR), ('LOO REPARAM pred', LOO_REPARAM_DIR)]:
    exists = os.path.exists(path)
    print(f'  {label}: {"EXISTS" if exists else "MISSING"} -- {path}')

Paths configured.
  NOLR train: EXISTS -- /Users/sarahurbut/Library/CloudStorage/Dropbox/censor_e_batchrun_vectorized_nolr
  REPARAM train: EXISTS -- /Users/sarahurbut/Library/CloudStorage/Dropbox/censor_e_batchrun_vectorized_REPARAM
  LOO NOLR pred: EXISTS -- /Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_fixedgk_nolr_loo/
  LOO REPARAM pred: EXISTS -- /Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_fixedgk_reparam_loo/


---
## Step 1: Batch Training (already done)

Training was run via shell scripts that loop over 40 batches of 10k patients (0–400k).
Each batch trains the full model (phi, psi, kappa, gamma, lambda) for 200 epochs.

**Nolr** (centered): `run_aladyn_batch_vector_e_censor_nolor.py`  
**Reparam** (non-centered): `run_aladyn_batch_vector_e_censor_nolor_reparam.py`

The shell scripts are `run_all_batches_reparam.sh` (and the deleted `run_all_batches_nolr.sh`).

To re-run a single batch (e.g., batch 0–10k, reparam):

In [None]:
# Example: train ONE reparam batch (0-10k). Remove the `if False` to actually run.
if False:
    import subprocess, sys
    subprocess.run([sys.executable, f'{SCRIPT_DIR}/run_aladyn_batch_vector_e_censor_nolor_reparam.py',
        '--start_index', '0', '--end_index', '10000',
        '--num_epochs', '200', '--learning_rate', '0.1',
        '--K', '20', '--W', '0.0001',
        '--data_dir', DATA_DIR, '--covariates_path', COV_PATH,
        '--output_dir', REPARAM_TRAIN_DIR])

In [None]:
# Check how many batch checkpoints exist
import glob
nolr_ckpts = sorted(glob.glob(f'{NOLR_TRAIN_DIR}/enrollment_model_VECTORIZED_W0.0001_nolr_batch_*_*.pt'))
reparam_ckpts = sorted(glob.glob(f'{REPARAM_TRAIN_DIR}/enrollment_model_REPARAM_W0.0001_batch_*_*.pt'))
print(f'Nolr batch checkpoints:    {len(nolr_ckpts)}')
print(f'Reparam batch checkpoints: {len(reparam_ckpts)}')

Nolr batch checkpoints:    40
Reparam batch checkpoints: 40


---
## Step 2: Pool Parameters Across Batches

Average phi, psi, kappa, gamma across the 39 training batches (batch 40 = holdout).
Produces a single `pooled_phi_kappa_gamma_{nolr|reparam}.pt`.

Script: `pool_phi_kappa_gamma_from_batches.py`

In [None]:
# Pool NOLR params (skip if file already exists)
import subprocess, sys
if not os.path.exists(POOLED_NOLR):
    subprocess.run([sys.executable, f'{SCRIPT_DIR}/pool_phi_kappa_gamma_from_batches.py',
        '--model_type', 'nolr', '--max_batches', '39',
        '--nolr_dir', NOLR_TRAIN_DIR, '--output_dir', DATA_DIR])
else:
    print(f'Already exists: {POOLED_NOLR}')

Already exists: /Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/pooled_phi_kappa_gamma_nolr.pt


In [None]:
# Pool REPARAM params (skip if file already exists)
if not os.path.exists(POOLED_REPARAM):
    subprocess.run([sys.executable, f'{SCRIPT_DIR}/pool_phi_kappa_gamma_from_batches.py',
        '--model_type', 'reparam', '--max_batches', '39',
        '--reparam_dir', REPARAM_TRAIN_DIR, '--output_dir', DATA_DIR])
else:
    print(f'Already exists: {POOLED_REPARAM}')

Already exists: /Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/pooled_phi_kappa_gamma_reparam.pt


In [None]:
# Inspect pooled params
import torch
import numpy as np

for label, path in [('NOLR', POOLED_NOLR), ('REPARAM', POOLED_REPARAM)]:
    if os.path.exists(path):
        d = torch.load(path, weights_only=False)
        phi = d['phi']
        gamma = d['gamma']
        kappa = d['kappa']
        phi_np = phi.numpy() if torch.is_tensor(phi) else np.array(phi)
        gamma_np = gamma.numpy() if torch.is_tensor(gamma) else np.array(gamma)
        k = kappa.item() if hasattr(kappa, 'item') else float(kappa)
        print(f'{label}: phi {phi_np.shape}, gamma {gamma_np.shape}, '
              f'kappa {k:.4f}, mean|gamma| {np.abs(gamma_np).mean():.4f}, '
              f'n_batches {d.get("n_batches", "?")}') 

NOLR: phi (21, 348, 52), gamma (47, 21), kappa 2.9319, mean|gamma| 0.0057, n_batches 39
REPARAM: phi (21, 348, 52), gamma (47, 21), kappa 4.5186, mean|gamma| 0.0814, n_batches 39


---
## Step 3: Leave-One-Out (LOO) Prediction

For each of the first 5 prediction batches (50k patients total):
1. **Pool** phi, psi, kappa, gamma from all 40 training batches **except** batch i
2. Fix these LOO-pooled population params
3. Optimize only individual-level params: **lambda** (nolr) or **delta** (reparam)
4. Save pi tensors

This eliminates data leakage: the prediction batch's own trained params never leak into the
pooled population params used to generate its predictions.

Script: `run_loo_predict_both.py`  
- Loads all 40 nolr + 40 reparam training checkpoints  
- For each eval batch, computes LOO-pooled params (excluding that batch)  
- Fits nolr model (optimize lambda) and reparam model (optimize delta)  
- Saves pi to `enrollment_predictions_fixedphi_fixedgk_{nolr|reparam}_loo/`

In [None]:
# LOO prediction — both nolr and reparam, first 5 batches (50k patients)
# Skip if pi files already exist
import subprocess, sys

loo_nolr_exists = os.path.exists(LOO_NOLR_DIR + 'pi_enroll_fixedphi_sex_0_10000.pt')
loo_reparam_exists = os.path.exists(LOO_REPARAM_DIR + 'pi_enroll_fixedphi_sex_0_10000.pt')

if not (loo_nolr_exists and loo_reparam_exists):
    subprocess.run([sys.executable, f'{SCRIPT_DIR}/run_loo_predict_both.py',
        '--n_pred_batches', '5',
        '--n_train_batches', '40',
        '--num_epochs', '200',
        '--learning_rate', '0.1',
        '--data_dir', DATA_DIR,
        '--nolr_train_dir', NOLR_TRAIN_DIR,
        '--reparam_train_dir', REPARAM_TRAIN_DIR,
        '--covariates_path', COV_PATH,
        '--output_base', DROPBOX])
else:
    print(f'LOO predictions already exist:')
    print(f'  nolr:    {LOO_NOLR_DIR}')
    print(f'  reparam: {LOO_REPARAM_DIR}')

LOO predictions already exist:
  nolr:    /Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_fixedgk_nolr_loo/
  reparam: /Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_fixedgk_reparam_loo/


In [None]:
# Check LOO prediction outputs
import glob, torch
for label, d in [('LOO NOLR', LOO_NOLR_DIR), ('LOO REPARAM', LOO_REPARAM_DIR)]:
    pis = sorted(glob.glob(d + 'pi_enroll_fixedphi_sex_*_*.pt'))
    print(f'{label}: {len(pis)} pi files in {d}')
    for p in pis[:5]:
        pi = torch.load(p, map_location='cpu', weights_only=False)
        nan_count = torch.isnan(pi).sum().item()
        print(f'  {os.path.basename(p)}: {pi.shape}, NaN: {nan_count}')

LOO NOLR: 5 pi files in /Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_fixedgk_nolr_loo/
  pi_enroll_fixedphi_sex_0_10000.pt: torch.Size([10000, 348, 52]), NaN: 0
  pi_enroll_fixedphi_sex_10000_20000.pt: torch.Size([10000, 348, 52]), NaN: 0
  pi_enroll_fixedphi_sex_20000_30000.pt: torch.Size([10000, 348, 52]), NaN: 0
  pi_enroll_fixedphi_sex_30000_40000.pt: torch.Size([10000, 348, 52]), NaN: 0
  pi_enroll_fixedphi_sex_40000_50000.pt: torch.Size([10000, 348, 52]), NaN: 0
LOO REPARAM: 5 pi files in /Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_fixedgk_reparam_loo/
  pi_enroll_fixedphi_sex_0_10000.pt: torch.Size([10000, 348, 52]), NaN: 0
  pi_enroll_fixedphi_sex_10000_20000.pt: torch.Size([10000, 348, 52]), NaN: 0
  pi_enroll_fixedphi_sex_20000_30000.pt: torch.Size([10000, 348, 52]), NaN: 0
  pi_enroll_fixedphi_sex_30000_40000.pt: torch.Size([10000, 348, 52]), NaN: 0
  pi_enroll_fixedphi_sex_40000_50000.pt: torch.Size([10000

---
## Step 4: LOO AUC Evaluation

Compare AUC across 3 metrics (static 10yr, dynamic 10yr, dynamic 1yr) using **LOO predictions**
on the first 5 batches (50k patients), with 100 bootstrap resamples.

Script: `compare_loo_auc.py`  
- Concatenates LOO pi tensors for all 5 eval batches  
- Evaluates per-disease AUC with bootstrapped CIs  
- Also compares LOO AUC against non-LOO AUC (to confirm no data leakage)  
- Output: `nolr_vs_reparam_5batches_auc_LOO.csv`

In [None]:
# Run LOO AUC comparison (100 bootstraps, ~30-60 min)
auc_csv = SCRIPT_DIR + '/nolr_vs_reparam_5batches_auc_LOO.csv'
if not os.path.exists(auc_csv):
    subprocess.run([sys.executable, f'{SCRIPT_DIR}/compare_loo_auc.py',
        '--n_bootstraps', '100', '--n_batches', '5'])
else:
    print(f'LOO AUC results already exist: {auc_csv}')

LOO AUC results already exist: /Users/sarahurbut/aladynoulli2/claudefile/nolr_vs_reparam_5batches_auc_LOO.csv


In [None]:
# Display LOO AUC results
import pandas as pd

df = pd.read_csv(auc_csv)

for horizon in ['static_10yr', 'dynamic_10yr', 'dynamic_1yr']:
    sub = df[df['horizon'] == horizon].copy()
    sub['delta'] = sub['reparam_auc'] - sub['nolr_auc']
    
    nm = sub['nolr_auc'].mean()
    rm = sub['reparam_auc'].mean()
    rw = (sub['reparam_auc'] > sub['nolr_auc']).sum()
    nw = (sub['nolr_auc'] > sub['reparam_auc']).sum()
    
    print(f'\n{horizon.upper()} (LOO)')
    print(f'  Mean AUC -- nolr: {nm:.3f}, reparam: {rm:.3f}, delta: {rm-nm:+.3f}')
    print(f'  Wins -- nolr: {nw}, reparam: {rw}')
    print(f'  Top 5 reparam gains:')
    top = sub.nlargest(5, 'delta')[['disease', 'nolr_auc', 'reparam_auc', 'delta']]
    for _, r in top.iterrows():
        print(f'    {r["disease"]:<25} {r["nolr_auc"]:.3f} -> {r["reparam_auc"]:.3f} ({r["delta"]:+.3f})')


STATIC_10YR (LOO)
  Mean AUC -- nolr: 0.622, reparam: 0.653, delta: +0.031
  Wins -- nolr: 3, reparam: 25
  Top 5 reparam gains:
    Diabetes                  0.629 -> 0.716 (+0.087)
    Bipolar_Disorder          0.459 -> 0.544 (+0.086)
    Depression                0.478 -> 0.540 (+0.062)
    Heart_Failure             0.705 -> 0.763 (+0.058)
    Ulcerative_Colitis        0.570 -> 0.627 (+0.056)

DYNAMIC_10YR (LOO)
  Mean AUC -- nolr: 0.624, reparam: 0.627, delta: +0.003
  Wins -- nolr: 12, reparam: 16
  Top 5 reparam gains:
    Atrial_Fib                0.653 -> 0.721 (+0.068)
    Diabetes                  0.648 -> 0.712 (+0.064)
    Crohns_Disease            0.531 -> 0.583 (+0.052)
    Breast_Cancer             0.555 -> 0.603 (+0.048)
    Bladder_Cancer            0.720 -> 0.762 (+0.042)

DYNAMIC_1YR (LOO)
  Mean AUC -- nolr: 0.765, reparam: 0.882, delta: +0.117
  Wins -- nolr: 4, reparam: 24
  Top 5 reparam gains:
    Asthma                    0.668 -> 0.963 (+0.295)
    Depression

---
## Summary of Key Differences

| | Centered (nolr) | Non-Centered (reparam) |
|---|---|---|
| **Optimized param** | lambda (free) | delta (lambda = G@gamma + delta) |
| **gamma in forward pass** | No (prior only) | Yes (sets lambda mean) |
| **kappa** | 2.93 | 4.52 |
| **mean \|gamma\|** | 0.006 | 0.081 |
| **Prediction init** | lambda near 0 | lambda at G@gamma |

## LOO AUC Results (50k patients, 100 bootstraps)

| Horizon | nolr | reparam | delta | reparam wins |
|---|---|---|---|---|
| Static 10yr | 0.622 | 0.653 | +0.031 | 25/28 |
| Dynamic 10yr | 0.624 | 0.627 | +0.003 | 16/28 |
| Dynamic 1yr | 0.765 | 0.882 | +0.117 | 24/28 |

LOO vs non-LOO AUC difference < 0.002, confirming no data leakage.