# Nokappa Pipeline: Pool → LOO Predict → AUC Compare

Training is complete (40 batches, 500 epochs each, cosine LR schedule + gradient clipping, kappa fixed at 1).
Checkpoints in: `Dropbox/censor_e_batchrun_vectorized_REPARAM_v2_nokappa/`

## Steps
1. **Pool** phi, psi, gamma across 39 training batches (batch 40 = holdout)
2. **LOO Predict** on 5 eval batches: for each, pool from all checkpoints EXCEPT that batch, then optimize delta only
3. **AUC Compare** nokappa vs v1 reparam vs nolr (3 horizons, 100 bootstraps)

## Step 1: Pool Nokappa Params

Script: `pool_phi_kappa_gamma_from_batches.py --model_type nokappa`

Output: `pooled_phi_kappa_gamma_nokappa.pt` in `data_for_running/`
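
As a rough illustration of what pooling means here, the sketch below averages each parameter element-wise across per-batch checkpoints. It is not the code in `pool_phi_kappa_gamma_from_batches.py`: the helper `pool_params`, the dict layout, the toy shapes, and the choice of a plain mean are all assumptions made for illustration.

```python
import numpy as np

def pool_params(batch_params, max_batches=None):
    """Element-wise mean of each named parameter across batches (hypothetical helper)."""
    used = batch_params[:max_batches]  # max_batches=None keeps every batch
    if not used:
        raise ValueError("no batches to pool")
    pooled = {name: np.mean([b[name] for b in used], axis=0) for name in used[0]}
    pooled["n_batches"] = len(used)
    return pooled

# Toy checkpoints: batch b holds constant arrays equal to b (shapes are made up).
toy_batches = [
    {"phi": np.full((2, 3), float(b)), "gamma": np.full((4,), float(b))}
    for b in range(1, 41)
]

# Pool the first 39 batches so that batch 40 stays out as the holdout.
pooled_toy = pool_params(toy_batches, max_batches=39)
print(pooled_toy["phi"].shape, pooled_toy["n_batches"])
```

Storing `n_batches` alongside the pooled arrays mirrors the `n_batches` key that the inspection cell reads back from the saved file.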

In [None]:
import subprocess, sys, os

SCRIPT_DIR = '/Users/sarahurbut/aladynoulli2/claudefile'
DATA_DIR = '/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/'

# check=True raises CalledProcessError if the pooling script exits non-zero
subprocess.run([sys.executable, f'{SCRIPT_DIR}/pool_phi_kappa_gamma_from_batches.py',
    '--model_type', 'nokappa', '--max_batches', '39'], check=True)

In [None]:
# Inspect pooled params
import torch, numpy as np

pooled = torch.load(DATA_DIR + 'pooled_phi_kappa_gamma_nokappa.pt', weights_only=False)
phi = pooled['phi']
gamma = pooled['gamma']
kappa = pooled['kappa']
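# detach().cpu() keeps the conversion safe for GPU tensors or tensors that track gradients
phi_np = phi if isinstance(phi, np.ndarray) else phi.detach().cpu().numpy()
gamma_np = gamma if isinstance(gamma, np.ndarray) else gamma.detach().cpu().numpy()
k = float(kappa)  # works for Python scalars, NumPy scalars, and single-element tensors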

print(f'phi: {phi_np.shape}')
print(f'gamma: {gamma_np.shape}')
print(f'kappa: {k:.4f} (should be ~1.0)')
print(f'mean|gamma|: {np.abs(gamma_np).mean():.4f}')
print(f'n_batches: {pooled.get("n_batches", "?")}')
if 'psi' in pooled:
    psi = pooled['psi']
    psi_np = psi if isinstance(psi, np.ndarray) else psi.detach().cpu().numpy()
    print(f'psi: {psi_np.shape}')

## Step 2: LOO Prediction (5 batches, 50k patients)

Script: `run_loo_predict_nokappa.py`

For each eval batch i:
- Pool phi/psi/gamma from all 40 checkpoints EXCEPT batch i
- Fix those params, optimize delta only (200 epochs, lr=0.1)
- Save pi tensor

Output: `Dropbox/enrollment_predictions_fixedphi_fixedgk_nokappa_loo/`
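
The leave-one-out rule above can be sketched as a small index helper. This is an illustration only, not the logic of `run_loo_predict_nokappa.py`: the function `loo_pool_indices` and the 0-based batch numbering are assumptions.

```python
def loo_pool_indices(eval_batch, n_train_batches):
    """Indices of the checkpoints pooled when `eval_batch` is held out (hypothetical helper)."""
    if not 0 <= eval_batch < n_train_batches:
        raise ValueError("eval_batch must be one of the training batches")
    return [b for b in range(n_train_batches) if b != eval_batch]

# Mirror the settings used in this notebook: 5 eval batches, 40 checkpoints.
loo_plan = {i: loo_pool_indices(i, n_train_batches=40) for i in range(5)}
for i, pooled_from in loo_plan.items():
    print(f"eval batch {i}: pool {len(pooled_from)} checkpoints, "
          f"held out: {i not in pooled_from}")
```

Because the evaluated batch never contributes to the pooled phi/psi/gamma, only delta is fit on the patients being scored.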

In [None]:
# check=True stops the notebook here if the LOO prediction script fails
subprocess.run([sys.executable, f'{SCRIPT_DIR}/run_loo_predict_nokappa.py',
    '--n_pred_batches', '5',
    '--n_train_batches', '40',
    '--num_epochs', '200',
    '--learning_rate', '0.1'], check=True)

In [None]:
# Check LOO output
import glob
LOO_DIR = '/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_fixedgk_nokappa_loo/'
pis = sorted(glob.glob(LOO_DIR + 'pi_enroll_fixedphi_sex_*_*.pt'))
print(f'{len(pis)} pi files:')
for p in pis:
    pi = torch.load(p, map_location='cpu', weights_only=False)
    nan_count = torch.isnan(pi).sum().item()
    print(f'  {os.path.basename(p)}: {pi.shape}, NaN: {nan_count}')

## Step 3: AUC Comparison (nokappa vs v1 reparam vs nolr)

Script: `compare_nokappa_auc.py`

Evaluates 3 horizons (static 10yr, dynamic 10yr, dynamic 1yr) with 100 bootstraps.
Compares against existing LOO results from `nolr_vs_reparam_5batches_auc_LOO.csv`.

Output: `nokappa_auc_LOO.csv`
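
For readers unfamiliar with bootstrapped AUC, here is a minimal sketch of the idea using only the standard library. It is not taken from `compare_nokappa_auc.py`; the helper names, the 95% percentile interval, and the toy data are assumptions.

```python
import random

def auc(labels, scores):
    """AUC as P(case score > control score), with ties counted as 0.5."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        return None  # undefined without both classes
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def bootstrap_auc(labels, scores, n_bootstraps=100, seed=0):
    """Point estimate plus a percentile interval from resampling patients with replacement."""
    rng = random.Random(seed)
    n = len(labels)
    stats = []
    for _ in range(n_bootstraps):
        idx = [rng.randrange(n) for _ in range(n)]
        a = auc([labels[i] for i in idx], [scores[i] for i in idx])
        if a is not None:  # skip resamples that lost one class
            stats.append(a)
    point = auc(labels, scores)
    if not stats:
        return point, None, None
    stats.sort()
    lo = stats[int(0.025 * (len(stats) - 1))]
    hi = stats[int(0.975 * (len(stats) - 1))]
    return point, lo, hi

# Toy data (made up): cases tend to score higher than controls.
toy_rng = random.Random(1)
toy_labels = [1] * 50 + [0] * 150
toy_scores = ([toy_rng.gauss(1.0, 1.0) for _ in range(50)]
              + [toy_rng.gauss(0.0, 1.0) for _ in range(150)])
point, ci_lo, ci_hi = bootstrap_auc(toy_labels, toy_scores, n_bootstraps=100)
print(f"AUC {point:.3f} (95% CI {ci_lo:.3f} to {ci_hi:.3f})")
```

The pairwise loop is quadratic in the number of patients, so it only suits small examples; a rank-based implementation would be the usual choice at the scale of this pipeline.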

In [None]:
# check=True stops the notebook here if the comparison script fails
subprocess.run([sys.executable, f'{SCRIPT_DIR}/compare_nokappa_auc.py',
    '--n_bootstraps', '100', '--n_batches', '5'], check=True)

In [None]:
# Display results
import pandas as pd

nk = pd.read_csv(f'{SCRIPT_DIR}/nokappa_auc_LOO.csv')
old = pd.read_csv(f'{SCRIPT_DIR}/nolr_vs_reparam_5batches_auc_LOO.csv')

for horizon in ['static_10yr', 'dynamic_10yr', 'dynamic_1yr']:
    nk_h = nk[nk['horizon'] == horizon].set_index('disease')
    old_h = old[old['horizon'] == horizon].set_index('disease')
    diseases = sorted(set(nk_h.index) & set(old_h.index))
    
    # Mean AUC per model over the diseases present in both result files
    nolr_m = old_h.loc[diseases, 'nolr_auc'].mean()
    v1_m = old_h.loc[diseases, 'reparam_auc'].mean()
    nk_m = nk_h.loc[diseases, 'auc'].mean()

    # Count diseases where nokappa has the strictly higher AUC
    nk_wins_nolr = int((nk_h.loc[diseases, 'auc'] > old_h.loc[diseases, 'nolr_auc']).sum())
    nk_wins_v1 = int((nk_h.loc[diseases, 'auc'] > old_h.loc[diseases, 'reparam_auc']).sum())
    
    print(f'\n{horizon.upper()}')
    print(f'  nolr: {nolr_m:.3f}  v1_reparam: {v1_m:.3f}  nokappa: {nk_m:.3f}')
    print(f'  nokappa vs nolr: {nk_m-nolr_m:+.3f} (wins {nk_wins_nolr}/{len(diseases)})')
    print(f'  nokappa vs v1:   {nk_m-v1_m:+.3f} (wins {nk_wins_v1}/{len(diseases)})')