# ============================================================================
# RESULTS GENERATION DOCUMENTATION
# ============================================================================
"""
Time Horizon Predictions (5yr, 10yr, 30yr, static 10yr):
--------------------------------------------------------
Generated using: scripts/generate_time_horizon_predictions.py

This script processes ALL patients at once using pre-computed pi tensors:
- Uses evaluate_major_diseases_wsex_with_bootstrap_dynamic_from_pi() for dynamic predictions
- Uses evaluate_major_diseases_wsex_with_bootstrap_from_pi() for static predictions
- Computes AUC on pooled predictions (statistically better than batch-averaging)

Approaches:
- Pooled Enrollment: pi from enrollment_predictions_fixedphi_ENROLLMENT_pooled/pi_enroll_fixedphi_sex_FULL.pt
- Pooled Retrospective: pi from enrollment_predictions_fixedphi_RETROSPECTIVE_pooled/pi_enroll_fixedphi_sex_FULL.pt

Results saved to: results/time_horizons/{approach}/

Washout Predictions (1-year with 0yr, 1yr, 2yr offsets):
---------------------------------------------------------
Generated using: scripts/generate_washout_predictions.py

This script processes ALL patients at once using pre-computed pi tensors:
- Uses evaluate_major_diseases_wsex_with_bootstrap_dynamic_1year_different_start_end_numeric_sex()
- Computes AUC on pooled predictions

Results saved to: results/washout/{approach}/
"""

In [9]:

# ============================================================================
# STEP 0: ASSEMBLE FULL PI TENSORS (RUN ONCE, THEN MARK AS "NOT EVALUATED")
# ============================================================================
"""
IMPORTANT: This cell assembles batch pi tensors into full pi tensors.
- Run this ONCE before running the generation cells
- After assembly is complete, mark this cell as "not evaluated"
- This creates pi_enroll_fixedphi_sex_FULL.pt files needed by the generation scripts
"""

import subprocess
import sys
from pathlib import Path

# Set script directory (holds assemble_full_pi_tensor.py)
script_dir = Path('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks')

print("="*80)
print("ASSEMBLING FULL PI TENSORS FROM BATCH FILES")
print("="*80)
print("\nThis will concatenate all batch pi tensors (0-10000, 10000-20000, ..., 390000-400000)")
print("into single full tensors for 0-400K patients.")
print("\nNOTE: Run once, then mark this cell as 'not evaluated'.")
print("="*80)

# (approach flag passed to assemble_full_pi_tensor.py, FULL tensor that job is expected to produce)
ASSEMBLY_JOBS = [
    ('pooled_retrospective', 'enrollment_predictions_fixedphi_RETROSPECTIVE_pooled/pi_enroll_fixedphi_sex_FULL.pt'),
    ('pooled_enrollment', 'enrollment_predictions_fixedphi_ENROLLMENT_pooled/pi_enroll_fixedphi_sex_FULL.pt'),
]

assembly_results = {}
failed_approaches = []
for job_number, (approach, _full_path) in enumerate(ASSEMBLY_JOBS, start=1):
    print(f"\n{job_number}. Assembling {approach} pi tensor...")
    completed = subprocess.run([
        sys.executable,
        str(script_dir / 'assemble_full_pi_tensor.py'),
        '--approach', approach,
        '--max_patients', '400000'
    ], capture_output=True, text=True)
    assembly_results[approach] = completed
    print(completed.stdout)
    if completed.stderr:
        print("STDERR:", completed.stderr)
    if completed.returncode != 0:
        print(f"ERROR: Assembly failed with return code {completed.returncode}")
        failed_approaches.append(approach)

# Keep the original per-job names available for anything that inspects them later.
result1 = assembly_results['pooled_retrospective']
result2 = assembly_results['pooled_enrollment']

print("\n" + "="*80)
if failed_approaches:
    # Don't claim success when a job failed: that approach's FULL tensor may be
    # missing or stale, and the generation cells below would consume it.
    print("PI TENSOR ASSEMBLY FINISHED WITH ERRORS: " + ", ".join(failed_approaches))
else:
    print("PI TENSOR ASSEMBLY COMPLETE")
print("="*80)

built_paths = [full_path for approach, full_path in ASSEMBLY_JOBS if approach not in failed_approaches]
if built_paths:
    print("\nFull pi tensors should now be available at:")
    for full_path in built_paths:
        print(f"  - {full_path}")



ASSEMBLING FULL PI TENSORS FROM BATCH FILES

This will concatenate all batch pi tensors (0-10000, 10000-20000, ..., 390000-400000)
into single full tensors for 0-400K patients.

NOTE: Run once, then mark this cell as 'not evaluated'.

1. Assembling pooled_retrospective pi tensor...
ASSEMBLING FULL PI TENSOR: POOLED_RETROSPECTIVE
Base directory: /Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/enrollment_predictions_fixedphi_RETROSPECTIVE_pooled
Max patients: 400000
Batch size: 10000

Will assemble 40 batches (0-400000)
Loading batch 1/40: 0-10000... ✓ Shape: torch.Size([10000, 348, 52])
Loading batch 2/40: 10000-20000... ✓ Shape: torch.Size([10000, 348, 52])
Loading batch 3/40: 20000-30000... ✓ Shape: torch.Size([10000, 348, 52])
Loading batch 4/40: 30000-40000... ✓ Shape: torch.Size([10000, 348, 52])
Loading batch 5/40: 40000-50000... ✓ Shape: torch.Size([10000, 348, 52])
Loading batch 6/40: 50000-60000... ✓ Shape: torch.Size([10000, 348, 52])
Loading batch 7/40: 60000-70000...

In [None]:
# ============================================================================
# STEP 1: GENERATE RESULTS (RUN ONCE, THEN MARK AS "NOT EVALUATED")
# ============================================================================
"""
IMPORTANT: These cells generate the results CSV files.
- Run them ONCE to generate all results
- After results are generated, mark these cells as "not evaluated" to prevent re-running
- The results will be saved to results/time_horizons/ and results/washout/
"""

import subprocess
import sys
from pathlib import Path

# Location of the generation scripts invoked below.
script_dir = Path('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks')

banner = "="*80
print(banner)
print("GENERATING TIME HORIZON PREDICTIONS")
print(banner)
print("\nThis will generate 5yr, 10yr, 30yr, and static 10yr predictions")
print("for both pooled_enrollment and pooled_retrospective approaches.")
print("\nNOTE: This takes a while! Run once, then mark this cell as 'not evaluated'.")
print(banner)


def run_time_horizon_job(step, approach):
    """Run generate_time_horizon_predictions.py for *approach* and echo its output.

    Prints the child's stdout (and stderr when non-empty), then either a success
    line or a return-code warning, and returns the ``CompletedProcess``.
    """
    print(f"\n{step}. Generating {approach} time horizons...")
    command = [
        sys.executable,
        str(script_dir / 'generate_time_horizon_predictions.py'),
        '--approach', approach,
        '--horizons', '5,10,30,static10',
        '--n_bootstraps', '100',
    ]
    completed = subprocess.run(command, capture_output=True, text=True)
    print(completed.stdout)
    if completed.stderr:
        print("STDERR:", completed.stderr)
    if completed.returncode == 0:
        print(f"✓ {approach} completed successfully")
    else:
        print(f"\n⚠️  WARNING: Script exited with return code {completed.returncode}")
    return completed


# Pooled retrospective is the main approach; pooled enrollment is run for comparison.
result1 = run_time_horizon_job(1, 'pooled_retrospective')
result2 = run_time_horizon_job(2, 'pooled_enrollment')

print("\n" + banner)
print("TIME HORIZON PREDICTIONS COMPLETE")
print(banner)


GENERATING TIME HORIZON PREDICTIONS

This will generate 5yr, 10yr, 30yr, and static 10yr predictions
for both pooled_enrollment and pooled_retrospective approaches.

NOTE: This takes a while! Run once, then mark this cell as 'not evaluated'.

1. Generating pooled_retrospective time horizons...


In [None]:
# ============================================================================
# STEP 2: GENERATE WASHOUT PREDICTIONS (RUN ONCE, THEN MARK AS "NOT EVALUATED")
# ============================================================================

import subprocess
import sys
from pathlib import Path

# Bound here too: the STEP 0 / STEP 1 cells (the other places these names are
# defined) are meant to be marked "not evaluated" after their first run, so this
# cell must not depend on them.
script_dir = Path('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks')

print("="*80)
print("GENERATING WASHOUT PREDICTIONS")
print("="*80)
print("\nThis will generate 1-year predictions with 0yr, 1yr, 2yr washout")
print("for both pooled_enrollment and pooled_retrospective approaches.")
print("\nNOTE: This takes a while! Run once, then mark this cell as 'not evaluated'.")
print("="*80)


def run_washout_job(step, approach):
    """Run generate_washout_predictions.py for *approach*, echo its output and exit status.

    Returns the ``CompletedProcess`` so the caller can inspect ``returncode``.
    """
    print(f"\n{step}. Generating {approach} washout predictions...")
    completed = subprocess.run([
        sys.executable,
        str(script_dir / 'generate_washout_predictions.py'),
        '--approach', approach,
        '--n_bootstraps', '100'
    ], capture_output=True, text=True)
    print(completed.stdout)
    if completed.stderr:
        print("STDERR:", completed.stderr)
    # Surface failures (same reporting as the time-horizon cell); without this a
    # crashed run would still end with the COMPLETE banner.
    if completed.returncode != 0:
        print(f"\n⚠️  WARNING: Script exited with return code {completed.returncode}")
    else:
        print(f"✓ {approach} completed successfully")
    return completed


# Pooled retrospective is the main approach; pooled enrollment is run for comparison.
result1 = run_washout_job(1, 'pooled_retrospective')
result2 = run_washout_job(2, 'pooled_enrollment')

print("\n" + "="*80)
if result1.returncode == 0 and result2.returncode == 0:
    print("WASHOUT PREDICTIONS COMPLETE")
else:
    print("WASHOUT PREDICTIONS FINISHED WITH ERRORS")
print("="*80)

In [None]:
# ============================================================================
# STEP 3: LOAD GENERATED RESULTS
# ============================================================================
"""
After running the generation cells above, load the results here for analysis.
"""

import pandas as pd
import numpy as np
from pathlib import Path

results_base = Path('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/results')
RESULT_APPROACHES = ['pooled_retrospective', 'pooled_enrollment']


def strip_affixes(stem, prefix='', suffix=''):
    """Return *stem* with an exact leading *prefix* and trailing *suffix* removed.

    Unlike ``str.replace``, only the ends of the name are touched, so the same
    token appearing in the middle of a file name is left alone.
    """
    if prefix and stem.startswith(prefix):
        stem = stem[len(prefix):]
    if suffix and stem.endswith(suffix):
        stem = stem[:-len(suffix)]
    return stem


def load_result_tables(approach_dir, pattern, prefix, comparison_name, label):
    """Load one approach's per-setting result CSVs into a dict of DataFrames.

    Parameters
    ----------
    approach_dir : pathlib.Path
        Directory holding the CSVs for one approach.
    pattern : str
        Glob pattern selecting the per-setting files (e.g. ``'*_results.csv'``).
    prefix : str
        Leading text removed from each file stem to form its key ('' for none).
        The trailing ``'_results'`` is always removed.
    comparison_name : str
        File name of the optional cross-setting comparison CSV. It is stored
        under the key ``'comparison'`` when it exists.
    label : str
        Name shown in the progress messages.

    Returns
    -------
    dict[str, pandas.DataFrame]
        Tables read with their first column as the index. Files are visited in
        sorted order, so the dict order does not depend on the filesystem.
    """
    tables = {}
    for csv_path in sorted(approach_dir.glob(pattern)):
        name = strip_affixes(csv_path.stem, prefix, '_results')
        tables[name] = pd.read_csv(csv_path, index_col=0)
        print(f"  ✓ Loaded {label}/{name}")
    # Also load comparison file if exists
    comparison_file = approach_dir / comparison_name
    if comparison_file.exists():
        tables['comparison'] = pd.read_csv(comparison_file, index_col=0)
        print(f"  ✓ Loaded {label}/comparison")
    return tables


# Load time horizon results
print("Loading time horizon results...")
time_horizon_results = {}
for approach in RESULT_APPROACHES:
    approach_dir = results_base / 'time_horizons' / approach
    if approach_dir.exists():
        time_horizon_results[approach] = load_result_tables(
            approach_dir, '*_results.csv', '', 'comparison_all_horizons.csv', approach
        )

# Load washout results
print("\nLoading washout results...")
washout_results = {}
for approach in RESULT_APPROACHES:
    approach_dir = results_base / 'washout' / approach
    if approach_dir.exists():
        washout_results[approach] = load_result_tables(
            approach_dir, 'washout_*_results.csv', 'washout_',
            'washout_comparison_all_offsets.csv', approach
        )

print("\n" + "="*80)
print("RESULTS LOADED - READY FOR ANALYSIS")
print("="*80)


# Evaluate 1-year performance
* at year 0
* with sliding windows (i.e., between years 41 and 42, or 42 and 43, predicted at year 40, but now using the score for those years)
# Evaluate 10-year (1 - surv^10), 30-year (1 - surv^30), and static 10-year (using the 1-year prediction) performance, with and without washout
* compare to Delphi
* compare the batched enrollment-pooled and batched retrospective-pooled approaches

Phase 1: Document what you have (30 minutes)
Create a simple RESULTS_MANIFEST.md that lists:
What each notebook generates
Where the outputs are saved
What each CSV file contains
Key parameters used
This gives you a map without moving files.

In [None]:
# Washout analysis

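In washout_analysis_summary.ipynb, the code below generates 1-year predictions for the window starting at the time of prediction (t0 to t0+1), for year +1 to +2, and for year +2 to +3 (offsets of 0, 1 and 2 years), for the model trained AT ENROLLMENT. Note that the active `pi_filename` below points to enrollment_predictions_fixedphi_RETROSPECTIVE_pooled; the older enrollment_predictions_fixedphi path is commented out.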


# Load the full data once

fh_processed = pd.read_csv('/Users/sarahurbut/Library/Cloudstorage/Dropbox-Personal/baselinagefamh.csv')
from evaluatetdccode import *
# Define all batches (0-400K in 10K increments)
batches = [(i, i+10000) for i in range(0, 400000, 10000)]
print(f"\n2. PROCESSING {len(batches)} BATCHES")
print(f"Batches: {batches[:5]}...{batches[-5:]}")
# Define batches (same as training)
# Storage for results
washout_results = {
    '0yr': {},  # No washout
    '1yr': {},  # 1-year washout  
    '2yr': {}   # 2-year washout
}

# Run washout analysis on each batch
for start, stop in batches:
    print(f"\n=== Processing batch {start}-{stop} ===")
    
    # Load batch predictions
    #pi_filename = f"/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi/pi_enroll_fixedphi_sex_{start}_{stop}.pt"
    pi_filename = f"/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/enrollment_predictions_fixedphi_RETROSPECTIVE_pooled/pi_enroll_fixedphi_sex_{start}_{stop}.pt"
   
   
    #m=torch.load(f"/Users/sarahurbut/aladynoulli2/claudefile/output/model_enroll_fixedphi_sex_{start}_{stop}.pt")
    m=torch.load(f"/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/enrollment_predictions_fixedphi_RETROSPECTIVE_pooled/model_enroll_fixedphi_sex_{start}_{stop}.pt")

    print(m['model_state_dict']['gamma'].shape)
    pi_batch = torch.load(pi_filename)
    
    # Subset other data to match
    Y_batch = Y[start:stop]
    E_batch = E[start:stop] 
    pce_df_batch = fh_processed.iloc[start:stop].reset_index(drop=True)
    
    # Run washout analysis for this batch
    for washout_name, offset in [('0yr', 0), ('1yr', 1), ('2yr', 2)]:
        print(f"  Running {washout_name} washout...")
        
        results = evaluate_major_diseases_wsex_with_bootstrap_dynamic_1year_different_start_end_numeric_sex(
            pi=pi_batch,
            Y_100k=Y_batch,
            E_100k=E_batch,
            disease_names=essentials['disease_names'],
            pce_df=pce_df_batch,
            n_bootstraps=50,  # Fewer bootstraps per batch
            follow_up_duration_years=1,
            start_offset=offset
        )
        
        # Store results
        for disease, metrics in results.items():
            if disease not in washout_results[washout_name]:
                washout_results[washout_name][disease] = {
                    'aucs': [], 'cis': [], 'events': [], 'rates': []
                }
            
            washout_results[washout_name][disease]['aucs'].append(metrics['auc'])
            washout_results[washout_name][disease]['cis'].append((metrics['ci_lower'], metrics['ci_upper']))
            washout_results[washout_name][disease]['events'].append(metrics['n_events'])
            washout_results[washout_name][disease]['rates'].append(metrics['event_rate'])
    
    # Clean up memory
    del pi_batch, Y_batch, E_batch, pce_df_batch
    torch.cuda.empty_cache() if torch.cuda.is_available() else None

# Aggregate results across batches
print("\n=== AGGREGATED WASHOUT RESULTS ===")
for washout_name, diseases in washout_results.items():
    print(f"\n{washout_name.upper()} WASHOUT:")
    for disease, metrics in diseases.items():
        aucs = [a for a in metrics['aucs'] if not pd.isna(a)]
        if aucs:
            mean_auc = np.mean(aucs)
            print(f"  {disease}: {mean_auc:.3f} (from {len(aucs)} batches)")

In [None]:
# in compare_offset.ipynb

Run on AWS using aws_offsetmaster/forAWS_offsetmasterfix.py; the log is in age_offset_files_aws.log.

scp -i "/Users/sarahurbut/Downloads/sarahkey.pem" \
    ec2-user@ec2-3-81-0-40.compute-1.amazonaws.com:~/aladyn_project/output/age_offset_files.tar.gz \
    ~/Downloads/
("noullitwo")


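Then compare_offset.ipynb generates ROC curves for the model trained at enrollment (first batch, patients 0–10000) and for the models retrained at each later age offset (i.e., retrained). The code below loads offsets 0–9 (`range(years_to_use)` with `years_to_use = 10`).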

years_to_use = 10
disease_names = essentials['disease_names']

enrollment_ages = pce_df['age'].to_numpy()


def load_offset_pi(k):
    """Load the pi tensor predicted at age offset ``k`` years (file names carry the 0_10000 batch range)."""
    return torch.load(f"/Users/sarahurbut/Library/CloudStorage/Dropbox/pi_offset_using_pooled_retrospective_local/pi_enroll_fixedphi_age_offset_{k}_sex_0_10000_try2_withpcs_newrun_pooledall.pt")


# One pi tensor per offset, in offset order 0 .. years_to_use - 1.
pi_batches = list(map(load_offset_pi, range(years_to_use)))


from evaluatetdccode import *
results = evaluate_major_diseases_rolling_1year_roc_curves(
    pi_batches, Y_100k, E_100k, disease_names, pce_df, plot_group='ASCVD'
)

In [None]:
├── time_horizons/    # 10yr, 30yr, static 10yr
│   ├── pooled_enrollment/
│   ├── pooled_retrospective/
│   └── joint_phi/


In lifetime.ipynb, we calculate 10-year, 30-year and static 10-year predictions and save the per-batch results (as torch .pt files) to saved_results/, plus pooled_comparison_all_approaches.csv. The code used to generate them is below. Presumably we should instead run this once on the full pi tensors, which are in:

/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_ENROLLMENT_pooled/pi_enroll_fixedphi_sex_FULL.pt

and 


/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/enrollment_predictions_fixedphi_RETROSPECTIVE_pooled/pi_enroll_fixedphi_sex_FULL.pt


One thing to note: the `model` object whose state is filled in by `model.load_state_dict` below can also be generated with quick_model_dummy.py, but presumably we should use the pi tensors directly.


# Project helpers (wildcard import). NOTE(review): presumably this supplies the
# evaluate_major_diseases_* functions called in the loop below — confirm.
from fig5utils import *
import pandas as pd
import numpy as np
import torch

# Load full pce_df (sliced per batch by row position in the loop below)
pce_df_full = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/pce_prevent_full.csv')
# NOTE(review): `essentials` is not defined in this cell; it must already exist in the session.
disease_names = essentials['disease_names']

# Storage for results - SEPARATE variables for each analysis type.
# Each list receives one result dict per successfully evaluated batch.
# Fixed phi from ENROLLMENT data
fixed_enrollment_10yr_results = []
fixed_enrollment_30yr_results = []
fixed_enrollment_static_10yr_results = []

# Fixed phi from RETROSPECTIVE data
fixed_retrospective_10yr_results = []
fixed_retrospective_30yr_results = []
fixed_retrospective_static_10yr_results = []

# Load full tensors once (shared across both analyses).
# The guard checks only whether the name exists in this session, so a tensor
# loaded by an earlier run is reused as-is rather than reloaded.
if 'Y_full' not in globals():
    Y_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/Y_tensor.pt')
if 'E_full' not in globals():
    E_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_enrollment_full.pt')

import traceback

import torch

# Number of 10K-patient batches to evaluate. range(41) covers batch indices
# 0..40, i.e. patients 0-410000; a batch whose checkpoint file does not exist is
# reported by the FileNotFoundError handler and skipped.
# NOTE(review): the rest of the pipeline uses 40 batches (0-400000); confirm
# whether batch 40 is intended.
N_BATCHES = 41
BATCH_SIZE = 10000

# (label used in log lines, analysis_type tag, checkpoint path template,
#  10yr / 30yr / static-10yr result lists that receive this analysis's output)
FIXED_PHI_ANALYSES = [
    ('ENROLLMENT', 'fixed_enrollment',
     '/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_ENROLLMENT_pooled/model_enroll_fixedphi_sex_{start}_{end}.pt',
     fixed_enrollment_10yr_results, fixed_enrollment_30yr_results, fixed_enrollment_static_10yr_results),
    ('RETROSPECTIVE', 'fixed_retrospective',
     '/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/enrollment_predictions_fixedphi_RETROSPECTIVE_pooled/model_enroll_fixedphi_sex_{start}_{end}.pt',
     fixed_retrospective_10yr_results, fixed_retrospective_30yr_results, fixed_retrospective_static_10yr_results),
]


def evaluate_fixed_phi_checkpoint(ckpt_path, label, analysis_type, batch_idx,
                                  Y_batch, E_batch, pce_df_subset,
                                  results_10yr, results_30yr, results_static_10yr):
    """Load one fixed-phi checkpoint into the global ``model`` and run the three evaluations.

    Runs the dynamic 10-year, dynamic 30-year and static 10-year evaluations.
    Each finished result is tagged with ``batch_idx`` / ``analysis_type`` and
    appended to its list immediately, so a failure part-way through keeps the
    evaluations that already completed. A missing checkpoint or any evaluation
    error is printed (with traceback for the latter) rather than raised, so the
    batch loop continues.

    Uses the notebook globals ``model`` and ``disease_names``.
    """
    try:
        print(f"\n--- Fixed Phi ({label}) ---")
        fixed_ckpt = torch.load(ckpt_path, weights_only=False)
        model.load_state_dict(fixed_ckpt['model_state_dict'])

        # Update model.Y and model.N so forward() uses correct patients.
        # as_tensor(...).clone() accepts numpy or torch input and always copies,
        # without the warning torch.tensor() emits when handed an existing tensor.
        model.Y = torch.as_tensor(Y_batch, dtype=torch.float32).clone()
        model.N = Y_batch.shape[0]

        # Dynamic 10- and 30-year predictions
        for years, target in ((10, results_10yr), (30, results_30yr)):
            print(f"Fixed Phi ({label}) - {years} year predictions...")
            horizon_result = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
                model, Y_batch, E_batch, disease_names, pce_df_subset,
                n_bootstraps=100, follow_up_duration_years=years, patient_indices=None
            )
            horizon_result['batch_idx'] = batch_idx
            horizon_result['analysis_type'] = analysis_type
            target.append(horizon_result)

        # Static 10-year predictions (using 1-year score)
        print(f"Fixed Phi ({label}) - Static 10 year predictions...")
        static_result = evaluate_major_diseases_wsex_with_bootstrap(
            model=model,
            Y_100k=Y_batch,
            E_100k=E_batch,
            disease_names=disease_names,
            pce_df=pce_df_subset,
            n_bootstraps=100,
            follow_up_duration_years=10,
        )
        static_result['batch_idx'] = batch_idx
        static_result['analysis_type'] = analysis_type
        results_static_10yr.append(static_result)

    except FileNotFoundError:
        print(f"Fixed phi ({label}) checkpoint not found: {ckpt_path}")
    except Exception as e:
        print(f"Error processing fixed phi ({label}) checkpoint {batch_idx}: {e}")
        traceback.print_exc()


for batch_idx in range(N_BATCHES):
    start_idx = batch_idx * BATCH_SIZE
    end_idx = start_idx + BATCH_SIZE

    print(f"\n{'='*80}")
    print(f"Processing batch {batch_idx}: {start_idx} to {end_idx}")
    print(f"{'='*80}")

    # Get pce_df subset for this batch (positional rows)
    pce_df_subset = pce_df_full.iloc[start_idx:end_idx].copy().reset_index(drop=True)

    # Extract batch from full tensors (shared for both analyses)
    Y_batch = Y_full[start_idx:end_idx]
    E_batch = E_full[start_idx:end_idx]

    for label, analysis_type, path_template, res_10yr, res_30yr, res_static in FIXED_PHI_ANALYSES:
        evaluate_fixed_phi_checkpoint(
            path_template.format(start=start_idx, end=end_idx),
            label, analysis_type, batch_idx,
            Y_batch, E_batch, pce_df_subset,
            res_10yr, res_30yr, res_static,
        )

import os

import torch

# (file stem, label for the summary, per-batch result list)
SAVED_RESULT_LISTS = [
    ('fixed_enrollment_10yr_results', 'Fixed Enrollment - 10yr', fixed_enrollment_10yr_results),
    ('fixed_enrollment_30yr_results', 'Fixed Enrollment - 30yr', fixed_enrollment_30yr_results),
    ('fixed_enrollment_static_10yr_results', 'Fixed Enrollment - Static 10yr', fixed_enrollment_static_10yr_results),
    ('fixed_retrospective_10yr_results', 'Fixed Retrospective - 10yr', fixed_retrospective_10yr_results),
    ('fixed_retrospective_30yr_results', 'Fixed Retrospective - 30yr', fixed_retrospective_30yr_results),
    ('fixed_retrospective_static_10yr_results', 'Fixed Retrospective - Static 10yr', fixed_retrospective_static_10yr_results),
]

print(f"\n{'='*80}")
print("Completed processing all checkpoints!")
print(f"{'='*80}")
for _stem, label, batch_results in SAVED_RESULT_LISTS:
    print(f"{label}: {len(batch_results)} batches")

# ===== SAVE RESULTS TO DISK (to avoid rerunning long computation) =====
print(f"\n{'='*80}")
print("SAVING RESULTS TO DISK")
print(f"{'='*80}")

results_dir = '/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/saved_results/'
os.makedirs(results_dir, exist_ok=True)

# Save all 6 result lists
for stem, _label, batch_results in SAVED_RESULT_LISTS:
    torch.save(batch_results, os.path.join(results_dir, f'{stem}.pt'))

print(f"\n✓ All results saved to {results_dir}")
for stem, _label, batch_results in SAVED_RESULT_LISTS:
    print(f"  - {stem}.pt ({len(batch_results)} batches)")

print(f"\n{'='*80}")
print("To reload later, use:")
print(f"{'='*80}")
# These hints are f-strings: without the prefix they printed a literal
# '{results_dir}' instead of the real directory.
for stem, _label, _batch_results in SAVED_RESULT_LISTS:
    saved_path = os.path.join(results_dir, f'{stem}.pt')
    print(f"{stem} = torch.load('{saved_path}')")


# ===== AGGREGATE AND SAVE RESULTS =====
print(f"\n{'='*80}")
print("AGGREGATING RESULTS ACROSS ALL BATCHES")
print(f"{'='*80}")

def aggregate_results_to_dataframe(results_list, analysis_name):
    """
    Aggregate per-batch evaluation results into one summary row per disease.

    Each element of `results_list` is one batch's result dict. It maps disease
    name -> metrics dict (the keys read here are 'auc', 'ci_lower', 'ci_upper',
    'n_events' and 'event_rate'). It may also hold the metadata keys
    'batch_idx' and 'analysis_type', which are skipped.

    Parameters
    ----------
    results_list : list of dict
        Per-batch result dicts as described above.
    analysis_name : str
        Label used only in the warning printed when `results_list` is empty.

    Returns
    -------
    pandas.DataFrame
        One row per disease with at least one valid AUC, indexed by 'Disease'
        and sorted by 'AUC_median' (descending). Missing, None and NaN metric
        values are ignored. A DataFrame with no columns is returned when there
        is nothing to aggregate.
    """
    if not results_list:
        print(f"Warning: No results found for {analysis_name}")
        return pd.DataFrame()

    metadata_keys = {'batch_idx', 'analysis_type'}
    # Union of disease keys over every batch, in first-seen order. Reading only
    # results_list[0] silently dropped any disease absent from the first batch.
    disease_names_list = list(dict.fromkeys(
        key for result in results_list for key in result if key not in metadata_keys
    ))

    def valid_values(disease, metric):
        """Return `metric` for `disease` from every batch where it is present and not None/NaN."""
        values = []
        for result in results_list:
            metrics = result.get(disease)
            # pd.notna handles None as well as NaN (np.isnan(None) raises TypeError).
            if isinstance(metrics, dict) and metric in metrics and pd.notna(metrics[metric]):
                values.append(metrics[metric])
        return values

    aggregated_data = []
    for disease in disease_names_list:
        aucs = valid_values(disease, 'auc')
        if not aucs:  # Only add if we have at least one valid AUC
            continue
        ci_lowers = valid_values(disease, 'ci_lower')
        ci_uppers = valid_values(disease, 'ci_upper')
        n_events_list = valid_values(disease, 'n_events')
        event_rates = valid_values(disease, 'event_rate')

        aggregated_data.append({
            'Disease': disease,
            'AUC_median': np.median(aucs),
            'AUC_mean': np.mean(aucs),
            'AUC_std': np.std(aucs),
            'AUC_min': np.min(aucs),
            'AUC_max': np.max(aucs),
            'CI_lower_median': np.median(ci_lowers) if ci_lowers else np.nan,
            'CI_upper_median': np.median(ci_uppers) if ci_uppers else np.nan,
            'CI_lower_min': np.min(ci_lowers) if ci_lowers else np.nan,
            'CI_upper_max': np.max(ci_uppers) if ci_uppers else np.nan,
            'Total_Events': np.sum(n_events_list) if n_events_list else np.nan,
            'Mean_Event_Rate': np.mean(event_rates) if event_rates else np.nan,
            'N_Batches': len(aucs)
        })

    df = pd.DataFrame(aggregated_data)
    if not df.empty:
        df = df.set_index('Disease').sort_values('AUC_median', ascending=False)
    return df

import os

import numpy as np
import pandas as pd

# Aggregate all 6 result lists
print("\nAggregating Fixed Enrollment results...")
fixed_enrollment_10yr_df = aggregate_results_to_dataframe(fixed_enrollment_10yr_results, "Fixed Enrollment 10yr")
fixed_enrollment_30yr_df = aggregate_results_to_dataframe(fixed_enrollment_30yr_results, "Fixed Enrollment 30yr")
fixed_enrollment_static_10yr_df = aggregate_results_to_dataframe(fixed_enrollment_static_10yr_results, "Fixed Enrollment Static 10yr")

print("Aggregating Fixed Retrospective results...")
fixed_retrospective_10yr_df = aggregate_results_to_dataframe(fixed_retrospective_10yr_results, "Fixed Retrospective 10yr")
fixed_retrospective_30yr_df = aggregate_results_to_dataframe(fixed_retrospective_30yr_results, "Fixed Retrospective 30yr")
fixed_retrospective_static_10yr_df = aggregate_results_to_dataframe(fixed_retrospective_static_10yr_results, "Fixed Retrospective Static 10yr")

# comparison column name -> (CSV file name, aggregated DataFrame).
# Dict order is the column order of the combined comparison file.
aggregated_tables = {
    'Fixed_Enrollment_10yr': ('pooled_fixed_enrollment_10yr.csv', fixed_enrollment_10yr_df),
    'Fixed_Enrollment_30yr': ('pooled_fixed_enrollment_30yr.csv', fixed_enrollment_30yr_df),
    'Fixed_Enrollment_Static_10yr': ('pooled_fixed_enrollment_static_10yr.csv', fixed_enrollment_static_10yr_df),
    'Fixed_Retrospective_10yr': ('pooled_fixed_retrospective_10yr.csv', fixed_retrospective_10yr_df),
    'Fixed_Retrospective_30yr': ('pooled_fixed_retrospective_30yr.csv', fixed_retrospective_30yr_df),
    'Fixed_Retrospective_Static_10yr': ('pooled_fixed_retrospective_static_10yr.csv', fixed_retrospective_static_10yr_df),
}

# Save individual DataFrames
output_dir = '/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/'
print(f"\nSaving aggregated results to {output_dir}...")
for csv_name, table in aggregated_tables.values():
    table.to_csv(os.path.join(output_dir, csv_name))
print("✓ Saved individual result files")

# Create a combined comparison DataFrame (similar to comparison_all_approaches format)
print("\nCreating combined comparison DataFrame...")
all_diseases = set()
for _csv_name, table in aggregated_tables.values():
    if not table.empty:
        all_diseases.update(table.index)

comparison_df = pd.DataFrame(index=sorted(all_diseases))
for column, (_csv_name, table) in aggregated_tables.items():
    # aggregate_results_to_dataframe returns a column-less DataFrame when an
    # analysis produced no valid results; indexing ['AUC_median'] on it raised
    # KeyError and aborted the cell. Fill that approach's column with NaN instead.
    comparison_df[column] = table['AUC_median'] if 'AUC_median' in table.columns else np.nan

comparison_df.to_csv(os.path.join(output_dir, 'pooled_comparison_all_approaches.csv'))
print("✓ Saved combined comparison file: pooled_comparison_all_approaches.csv")

# Print summary
print(f"\n{'='*80}")
print("SUMMARY OF AGGREGATED RESULTS")
print(f"{'='*80}")
print(f"\nFixed Enrollment - 10yr: {len(fixed_enrollment_10yr_df)} diseases")
print(f"Fixed Enrollment - 30yr: {len(fixed_enrollment_30yr_df)} diseases")
print(f"Fixed Enrollment - Static 10yr: {len(fixed_enrollment_static_10yr_df)} diseases")
print(f"Fixed Retrospective - 10yr: {len(fixed_retrospective_10yr_df)} diseases")
print(f"Fixed Retrospective - 30yr: {len(fixed_retrospective_30yr_df)} diseases")
print(f"Fixed Retrospective - Static 10yr: {len(fixed_retrospective_static_10yr_df)} diseases")

print(f"\n{'='*80}")
print("TOP 10 DISEASES BY AUC (Fixed Enrollment 10yr)")
print(f"{'='*80}")
if not fixed_enrollment_10yr_df.empty:
    print(fixed_enrollment_10yr_df[['AUC_median', 'CI_lower_median', 'CI_upper_median', 'N_Batches']].head(10).round(4))

print(f"\n{'='*80}")
print("TOP 10 DISEASES BY AUC (Fixed Retrospective 10yr)")
print(f"{'='*80}")
if not fixed_retrospective_10yr_df.empty:
    print(fixed_retrospective_10yr_df[['AUC_median', 'CI_lower_median', 'CI_upper_median', 'N_Batches']].head(10).round(4))

print(f"\n{'='*80}")
print("All results saved successfully!")
print(f"{'='*80}")

├── comparisons/      # Cross-approach comparisons
│   ├── vs_delphi/
│   ├── vs cox?
│   └── vs PCE/PREVENT (implicit in the evaluation functions?)



vs Delphi:

We need to do the newer comparison with Delphi (it is unclear what the supplement reports in its 0, 1, 2 columns). Some existing code is below — is this for the 5-year comparisons against Shmatko et al. (41586_2025_9529_MOESM3_ESM.csv)?

"""
Reproducible Analysis: Aladynoulli vs Delphi Disease Prediction Comparison
=============================================================================

This script compares Aladynoulli (PheCode-based) predictions with Delphi (ICD-10 based) 
predictions across 1-year, 5-year, and 10-year time horizons.

Input files:
- washout_summary_table.csv: Aladynoulli 1-year predictions with washout analysis
- median_auc_results_5_year.csv: Aladynoulli 5-year predictions
- median_auc_results_10yearjointphi.csv: Aladynoulli 10-year predictions
- 41586_2025_9529_MOESM3_ESM.csv: Delphi supplementary table (ICD-10 level)
"""

import pandas as pd
import numpy as np

# =============================================================================
# 1. LOAD ALADYNOULLI RESULTS
# =============================================================================

print("Loading Aladynoulli results...")

# 1-year predictions (0 washout = all data available)
washout = pd.read_csv('/Users/sarahurbut/aladynoulli2/claudefile/output/washout_summary_table.csv')
aladynoulli_1yr = washout[['Disease', '0yr_AUC']].copy()
aladynoulli_1yr.columns = ['Disease', 'Aladynoulli_1yr']

# 5-year predictions
year_5 = pd.read_csv('/Users/sarahurbut/aladynoulli2/claudefile/output/median_auc_results_5yearjointphi.csv')
aladynoulli_5yr = year_5[['Disease', 'MedianAUC']].copy()
aladynoulli_5yr.columns = ['Disease', 'Aladynoulli_5yr']

# 10-year predictions
year_10 = pd.read_csv('/Users/sarahurbut/aladynoulli2/claudefile/output/median_auc_results_10yearjointphi.csv')
aladynoulli_10yr = year_10[['Disease', 'MedianAUC']].copy()
aladynoulli_10yr.columns = ['Disease', 'Aladynoulli_10yr']

# Merge all Aladynoulli results
aladynoulli_all = aladynoulli_1yr.merge(aladynoulli_5yr, on='Disease', how='outer')
aladynoulli_all = aladynoulli_all.merge(aladynoulli_10yr, on='Disease', how='outer')

print(f"Loaded Aladynoulli results for {len(aladynoulli_all)} diseases")

# =============================================================================
# 2. EXTRACT DELPHI RESULTS FROM SUPPLEMENTARY TABLE
# =============================================================================

print("\nExtracting Delphi results from supplementary table...")

# Load Delphi supplementary table (1,270 ICD-10 codes)
delphi_supp = pd.read_csv('/Users/sarahurbut/Downloads/41586_2025_9529_MOESM3_ESM.csv')

# Define disease category to ICD-10 code mappings
# (These are the major disease categories from the Aladynoulli analysis.)
# A Delphi row belongs to a category when its 'Name' starts with one of the codes.
disease_icd_mapping = {
    'ASCVD': ['I21', 'I25'],  # Myocardial infarction, Coronary atherosclerosis
    'Diabetes': ['E11'],  # Type 2 diabetes
    'Atrial_Fib': ['I48'],  # Atrial fibrillation
    'CKD': ['N18'],  # Chronic renal failure
    # Colon, Breast, Prostate. Prostate was previously 'D07' (carcinoma in situ of
    # other genital organs), contradicting this comment and Prostate_Cancer below.
    'All_Cancers': ['C18', 'C50', 'C61'],
    'Stroke': ['I63'],  # Cerebral infarction
    'Heart_Failure': ['I50'],  # Heart failure
    'Pneumonia': ['J18'],  # Pneumonia
    'COPD': ['J44'],  # Chronic obstructive pulmonary disease
    'Osteoporosis': ['M81'],  # Osteoporosis
    'Anemia': ['D50'],  # Iron deficiency anemia
    'Colorectal_Cancer': ['C18'],  # Colon cancer
    'Breast_Cancer': ['C50'],  # Breast cancer
    'Prostate_Cancer': ['C61'],  # Prostate cancer
    'Lung_Cancer': ['C34'],  # Lung cancer
    'Bladder_Cancer': ['C67'],  # Bladder cancer
    'Secondary_Cancer': ['C79'],  # Secondary malignant neoplasm
    'Depression': ['F32', 'F33'],  # Depressive disorders
    'Anxiety': ['F41'],  # Anxiety disorders
    'Bipolar_Disorder': ['F31'],  # Bipolar disorder
    'Rheumatoid_Arthritis': ['M05', 'M06'],  # Rheumatoid arthritis
    'Psoriasis': ['L40'],  # Psoriasis
    'Ulcerative_Colitis': ['K51'],  # Ulcerative colitis
    'Crohns_Disease': ['K50'],  # Crohn's disease
    'Asthma': ['J45'],  # Asthma
    'Parkinsons': ['G20'],  # Parkinson's disease
    'Multiple_Sclerosis': ['G35'],  # Multiple sclerosis
    'Thyroid_Disorders': ['E03']  # Hypothyroidism
}

# Extract Delphi AUCs for each disease category
delphi_results = []

for disease_name, icd_codes in disease_icd_mapping.items():
    # Rows whose Name begins with each code (literal prefix match, same as '^CODE').
    matching_rows = [
        delphi_supp[delphi_supp['Name'].str.startswith(icd_code, na=False)]
        for icd_code in icd_codes
    ]
    matching_rows = [rows for rows in matching_rows if len(rows) > 0]
    if not matching_rows:
        continue

    combined = pd.concat(matching_rows)

    # Pool the female and male AUCs of every matched row into one unweighted mean.
    female_aucs = combined['AUC Female, (no gap)'].dropna()
    male_aucs = combined['AUC Male, (no gap)'].dropna()
    all_aucs = pd.concat([female_aucs, male_aucs])
    if all_aucs.empty:
        continue

    delphi_results.append({
        'Disease': disease_name,
        'Delphi_1yr': all_aucs.mean(),
        'N_ICD_codes': len(combined),  # number of matched Delphi rows
    })

# Explicit columns so an empty result still has 'Disease'/'Delphi_1yr' for the merge below.
delphi_df = pd.DataFrame(delphi_results, columns=['Disease', 'Delphi_1yr', 'N_ICD_codes'])
print(f"Extracted Delphi results for {len(delphi_df)} diseases")

# =============================================================================
# 3. MERGE AND COMPARE
# =============================================================================

print("\nCreating comparison...")

# Outer join with Delphi, add the Aladynoulli-minus-Delphi 1-year AUC difference,
# and order diseases from largest to smallest difference.
comparison = (
    aladynoulli_all
    .merge(delphi_df[['Disease', 'Delphi_1yr']], on='Disease', how='outer')
    .assign(Diff_1yr=lambda merged: merged['Aladynoulli_1yr'] - merged['Delphi_1yr'])
    .sort_values('Diff_1yr', ascending=False)
)

# =============================================================================
# 4. IDENTIFY WINS
# =============================================================================

# `comparison` is an outer merge, so it also holds diseases that have only one of the
# two AUCs (Diff_1yr is NaN there). Only rows with both AUCs can be compared, so they
# form the denominator of the win rate.
n_comparable = int(comparison['Diff_1yr'].notna().sum())
wins = comparison[comparison['Diff_1yr'] > 0].copy()

print("\n" + "=" * 100)
print("ALADYNOULLI vs DELPHI: DISEASES WHERE ALADYNOULLI WINS (1-YEAR PREDICTIONS)")
print("=" * 100)
print(f"\nTotal wins: {len(wins)} out of {n_comparable} comparable diseases "
      f"({len(comparison)} diseases in the merged table)")
if n_comparable > 0:
    print(f"Win rate: {len(wins) / n_comparable * 100:.1f}%\n")
else:
    print("Win rate: n/a (no disease has both an Aladynoulli and a Delphi AUC)\n")

pd.set_option('display.float_format', '{:.4f}'.format)

print(f"{'Disease':<25} {'Aladynoulli':>12} {'Delphi':>12} {'Advantage':>12} {'Percent':>10}")
print("-" * 100)

for idx, row in wins.iterrows():
    disease = row['Disease']
    ala = row['Aladynoulli_1yr']
    delp = row['Delphi_1yr']
    diff = row['Diff_1yr']
    # Relative advantage over Delphi; guarded against a zero/negative Delphi AUC.
    pct = (diff / delp * 100) if delp > 0 else 0

    print(f"{disease:<25} {ala:>12.4f} {delp:>12.4f} {diff:>12.4f} {pct:>9.1f}%")

# =============================================================================
# 5. SUMMARY STATISTICS
# =============================================================================

print("\n" + "=" * 100)
print("SUMMARY STATISTICS")
print("=" * 100)

# Diseases with both 1-year AUCs. 'Overall difference' is the mean over exactly these
# rows, whereas each '(all)' mean covers every disease that has that model's AUC, so
# the '(matched)' means are the ones that reconcile with the difference.
comparable_1yr = comparison.dropna(subset=['Aladynoulli_1yr', 'Delphi_1yr'])

print(f"\n1-YEAR PREDICTIONS:")
print(f"  Aladynoulli mean (all):  {comparison['Aladynoulli_1yr'].mean():.4f}")
print(f"  Delphi mean (all):       {comparison['Delphi_1yr'].mean():.4f}")
print(f"  Aladynoulli mean (matched): {comparable_1yr['Aladynoulli_1yr'].mean():.4f}")
print(f"  Delphi mean (matched):      {comparable_1yr['Delphi_1yr'].mean():.4f}")
print(f"  Overall difference:      {comparison['Diff_1yr'].mean():.4f}")

print(f"\n  Aladynoulli mean (wins): {wins['Aladynoulli_1yr'].mean():.4f}")
print(f"  Delphi mean (wins):      {wins['Delphi_1yr'].mean():.4f}")
print(f"  Average advantage:       {wins['Diff_1yr'].mean():.4f}")

# No Delphi counterpart at these horizons; Aladynoulli only.
print(f"\n5-YEAR PREDICTIONS:")
print(f"  Aladynoulli mean:        {comparison['Aladynoulli_5yr'].mean():.4f}")

print(f"\n10-YEAR PREDICTIONS:")
print(f"  Aladynoulli mean:        {comparison['Aladynoulli_10yr'].mean():.4f}")

# =============================================================================
# 6. TOP WINS BY CATEGORY
# =============================================================================

print("\n" + "=" * 100)
print("TOP 5 BIGGEST WINS")
print("=" * 100)

# `wins` inherits the descending Diff_1yr order, so the first 5 rows are the largest gains.
top5 = wins.head(5)
for rank, (_, win_row) in enumerate(top5.iterrows(), start=1):
    a_auc = win_row['Aladynoulli_1yr']
    d_auc = win_row['Delphi_1yr']
    gain = win_row['Diff_1yr']
    print(f"\n{rank}. {win_row['Disease']}")
    print(f"   Aladynoulli: {a_auc:.4f}")
    print(f"   Delphi:      {d_auc:.4f}")
    print(f"   Advantage:   +{gain:.4f} ({gain/d_auc*100:.1f}% better)")

# =============================================================================
# 7. SAVE RESULTS
# =============================================================================

print("\n" + "=" * 100)
print("SAVING RESULTS")
print("=" * 100)

# Write the full comparison and the wins-only subset, echoing each destination.
for out_frame, out_path, out_label in (
    (comparison, '/Users/sarahurbut/aladynoulli2/claudefile/output//comparison_aladynoulli_vs_delphi_full.csv', "Full comparison"),
    (wins, '/Users/sarahurbut/aladynoulli2/claudefile/output//comparison_aladynoulli_vs_delphi_wins.csv', "Wins only"),
):
    out_frame.to_csv(out_path, index=False)
    print(f"{out_label} saved to: {out_path}")

print("\n" + "=" * 100)
print("ANALYSIS COMPLETE")
print("=" * 100)

#### COMPARISON VIA COX

# Compare Fixed_Retrospective_Pooled against a Cox baseline (age + sex only, without
# Aladynoulli) on UK Biobank; the baseline AUCs were computed in the R script tdccdoe20.R.
import pandas as pd
import numpy as np

# Load Cox baseline results (age + sex only, no Aladyn)
cox_df = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox/auc_results_cox_20000_30000train_0_10000test_1121.csv')
# Take first occurrence of each disease (remove duplicates)
cox_baseline = cox_df.groupby('disease_group')['auc'].first().to_dict()

# Load your Fixed_Retrospective_Pooled results (maybe replace 10 year with static?)
df_10yr = pd.read_csv('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/comparison_all_approaches_10yr.csv', index_col=0)
df_30yr = pd.read_csv('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/comparison_all_approaches_30yr.csv', index_col=0)

auc_10yr = df_10yr['Fixed_Retrospective_Pooled']
auc_30yr = df_30yr['Fixed_Retrospective_Pooled']

# Create comparison (one row per disease present in both the 10-year table and the Cox baseline)
comparison_columns = [
    'Disease', 'Cox_Baseline_AUC', 'Your_10yr_AUC', 'Your_30yr_AUC',
    'Improvement_10yr', 'Improvement_30yr',
]
comparison_data = []
for disease in df_10yr.index:
    if disease not in cox_baseline:
        continue
    cox_auc = cox_baseline[disease]
    auc_10 = auc_10yr.loc[disease]
    # A disease missing from the 30-year table used to raise KeyError; record NaN instead.
    auc_30 = auc_30yr.get(disease, np.nan)
    comparison_data.append({
        'Disease': disease,
        'Cox_Baseline_AUC': cox_auc,
        'Your_10yr_AUC': auc_10,
        'Your_30yr_AUC': auc_30,
        'Improvement_10yr': auc_10 - cox_auc,
        'Improvement_30yr': auc_30 - cox_auc,
    })

if not comparison_data:
    print("WARNING: no disease is present in both the 10-year results and the Cox baseline")

# Explicit columns keep set_index('Disease') valid even when there is no overlap.
comparison_df = (
    pd.DataFrame(comparison_data, columns=comparison_columns)
    .set_index('Disease')
    .sort_values('Improvement_10yr', ascending=False)
)

n_10yr = int(comparison_df['Improvement_10yr'].notna().sum())
n_30yr = int(comparison_df['Improvement_30yr'].notna().sum())

print("="*100)
print("COMPARISON: Your Fixed_Retrospective_Pooled vs Cox Baseline (Age + Sex only) on UK Biobank")
print("="*100)
print("\n10-YEAR PREDICTIONS:")
print("-"*100)
print(comparison_df[['Cox_Baseline_AUC', 'Your_10yr_AUC', 'Improvement_10yr']].round(4))
print(f"\nMean improvement: {comparison_df['Improvement_10yr'].mean():.4f}")
print(f"Median improvement: {comparison_df['Improvement_10yr'].median():.4f}")
print(f"Diseases with improvement >0.05: {(comparison_df['Improvement_10yr'] > 0.05).sum()} / {n_10yr}")
print(f"Diseases with improvement >0.10: {(comparison_df['Improvement_10yr'] > 0.10).sum()} / {n_10yr}")

print("\n" + "="*100)
print("30-YEAR PREDICTIONS:")
print("-"*100)
print(comparison_df[['Cox_Baseline_AUC', 'Your_30yr_AUC', 'Improvement_30yr']].round(4))
print(f"\nMean improvement: {comparison_df['Improvement_30yr'].mean():.4f}")
print(f"Median improvement: {comparison_df['Improvement_30yr'].median():.4f}")
# Denominator counts only diseases that have a 30-year AUC.
print(f"Diseases with improvement >0.05: {(comparison_df['Improvement_30yr'] > 0.05).sum()} / {n_30yr}")
print(f"Diseases with improvement >0.10: {(comparison_df['Improvement_30yr'] > 0.10).sum()} / {n_30yr}")

print("\n" + "="*100)
print("TOP IMPROVEMENTS (10-year):")
print("-"*100)
print(comparison_df.nlargest(10, 'Improvement_10yr')[['Cox_Baseline_AUC', 'Your_10yr_AUC', 'Improvement_10yr']].round(4))

# Save comparison
comparison_df.to_csv('comparison_vs_cox_baseline.csv')
print("\n✓ Saved to comparison_vs_cox_baseline.csv")


Validation:
    leave-one-out (LOO) and AWS

* LOO (also in lifetime.ipynb): compare Leave-One-Out vs Full Pooled results
# For batches that were excluded in LOO validation
import traceback

import torch

# Batches excluded in LOO (from the folder list)
excluded_batches = [0, 6, 15, 17, 18, 20, 24, 34, 35, 37]

# Storage for results
loo_10yr_results = []
loo_30yr_results = []
loo_static_10yr_results = []

full_pooled_10yr_results = []
full_pooled_30yr_results = []
full_pooled_static_10yr_results = []

# Load full tensors once
if 'Y_full' not in globals():
    Y_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/Y_tensor.pt')
if 'E_full' not in globals():
    E_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_enrollment_full.pt')


def _evaluate_checkpoint_into(ckpt_path, batch_idx, analysis_type, step_label,
                              model, Y_batch, E_batch, disease_names, pce_df_subset,
                              results_10yr, results_30yr, results_static_10yr,
                              n_bootstraps=100):
    """Load one checkpoint into `model` and evaluate it on one batch of patients.

    Runs, in order, the dynamic 10-year and 30-year evaluations
    (`evaluate_major_diseases_wsex_with_bootstrap_dynamic`) and the static 10-year
    evaluation (`evaluate_major_diseases_wsex_with_bootstrap`). Each result dict is
    tagged with 'batch_idx' and 'analysis_type' and appended to its list right away,
    so if a later horizon fails, the earlier ones are kept.

    Args:
        ckpt_path: checkpoint file containing a 'model_state_dict' entry.
        batch_idx: value stored under 'batch_idx' in every result.
        analysis_type: value stored under 'analysis_type' in every result.
        step_label: prefix of the progress messages (e.g. "LOO").
        model: model whose state and Y/N attributes are overwritten in place.
        Y_batch, E_batch, disease_names, pce_df_subset: passed to the evaluators.
        results_10yr, results_30yr, results_static_10yr: lists that receive the results.
        n_bootstraps: bootstrap replicates per evaluation.

    Raises:
        Whatever torch.load, load_state_dict or the evaluators raise (e.g.
        FileNotFoundError for a missing checkpoint); the caller handles it.
    """
    ckpt = torch.load(ckpt_path, weights_only=False)
    model.load_state_dict(ckpt['model_state_dict'])

    # Point the model at this batch's patients (fresh float32 copy, no autograd history).
    model.Y = torch.as_tensor(Y_batch, dtype=torch.float32).detach().clone()
    model.N = Y_batch.shape[0]

    for years, sink in ((10, results_10yr), (30, results_30yr)):
        print(f"{step_label} - {years} year predictions...")
        result = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
            model, Y_batch, E_batch, disease_names, pce_df_subset,
            n_bootstraps=n_bootstraps, follow_up_duration_years=years, patient_indices=None
        )
        result['batch_idx'] = batch_idx
        result['analysis_type'] = analysis_type
        sink.append(result)

    print(f"{step_label} - Static 10 year predictions...")
    result = evaluate_major_diseases_wsex_with_bootstrap(
        model=model,
        Y_100k=Y_batch,
        E_100k=E_batch,
        disease_names=disease_names,
        pce_df=pce_df_subset,
        n_bootstraps=n_bootstraps,
        follow_up_duration_years=10,
    )
    result['batch_idx'] = batch_idx
    result['analysis_type'] = analysis_type
    results_static_10yr.append(result)


# Loop through excluded batches
for batch_idx in excluded_batches:
    start_idx = batch_idx * 10000
    end_idx = (batch_idx + 1) * 10000

    print(f"\n{'='*80}")
    print(f"Processing batch {batch_idx}: {start_idx} to {end_idx}")
    print(f"{'='*80}")

    # Get pce_df subset for this batch.
    # NOTE(review): pce_df_full is loaded in the AWS validation cell further down;
    # when running top-to-bottom it must already exist in the session here.
    pce_df_subset = pce_df_full[start_idx:end_idx].copy().reset_index(drop=True)

    # Extract batch from full tensors
    Y_batch = Y_full[start_idx:end_idx]
    E_batch = E_full[start_idx:end_idx]

    # ===== LEAVE-ONE-OUT RESULTS =====
    loo_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/leave_one_out_validation/batch_{batch_idx}/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'

    try:
        print(f"\n--- Leave-One-Out (excluded batch {batch_idx}) ---")
        _evaluate_checkpoint_into(
            loo_ckpt_path, batch_idx, 'leave_one_out', "LOO",
            model, Y_batch, E_batch, disease_names, pce_df_subset,
            loo_10yr_results, loo_30yr_results, loo_static_10yr_results,
        )
    except FileNotFoundError:
        print(f"LOO checkpoint not found: {loo_ckpt_path}")
    except Exception as e:
        # Best-effort: report and continue with the next checkpoint.
        print(f"Error processing LOO checkpoint {batch_idx}: {e}")
        traceback.print_exc()

    # ===== FULL POOLED RESULTS =====
    full_pooled_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/enrollment_predictions_fixedphi_RETROSPECTIVE_pooled/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'

    try:
        print(f"\n--- Full Pooled (all 40 batches) ---")
        _evaluate_checkpoint_into(
            full_pooled_ckpt_path, batch_idx, 'full_pooled', "Full Pooled",
            model, Y_batch, E_batch, disease_names, pce_df_subset,
            full_pooled_10yr_results, full_pooled_30yr_results, full_pooled_static_10yr_results,
        )
    except FileNotFoundError:
        print(f"Full pooled checkpoint not found: {full_pooled_ckpt_path}")
    except Exception as e:
        print(f"Error processing full pooled checkpoint {batch_idx}: {e}")
        traceback.print_exc()

print(f"\n{'='*80}")
print("Completed processing!")
print(f"{'='*80}")
print(f"LOO - 10yr: {len(loo_10yr_results)} batches")
print(f"LOO - 30yr: {len(loo_30yr_results)} batches")
print(f"LOO - Static 10yr: {len(loo_static_10yr_results)} batches")
print(f"Full Pooled - 10yr: {len(full_pooled_10yr_results)} batches")
print(f"Full Pooled - 30yr: {len(full_pooled_30yr_results)} batches")
print(f"Full Pooled - Static 10yr: {len(full_pooled_static_10yr_results)} batches")

# Extract AUCs and compare
def extract_aucs_from_results(results_list):
    """Collect per-disease AUCs from a list of evaluation result dicts.

    Each result maps disease names to metric dicts and also carries the bookkeeping
    keys 'batch_idx' and 'analysis_type'. Entries whose value is a dict containing
    'auc' are kept. Results that share a batch_idx are merged; a later result
    overwrites an earlier one for the same disease.

    Returns:
        {batch_idx: {disease: auc}}
    """
    by_batch = {}
    for res in results_list:
        per_disease = by_batch.setdefault(res['batch_idx'], {})
        per_disease.update(
            (name, metrics['auc'])
            for name, metrics in res.items()
            if name not in ('batch_idx', 'analysis_type')
            and isinstance(metrics, dict)
            and 'auc' in metrics
        )
    return by_batch

# Per-batch {disease: auc} maps for each horizon, leave-one-out and full pooled.
loo_10yr_aucs, loo_30yr_aucs, loo_static_10yr_aucs = (
    extract_aucs_from_results(results)
    for results in (loo_10yr_results, loo_30yr_results, loo_static_10yr_results)
)
full_10yr_aucs, full_30yr_aucs, full_static_10yr_aucs = (
    extract_aucs_from_results(results)
    for results in (full_pooled_10yr_results, full_pooled_30yr_results, full_pooled_static_10yr_results)
)

# Compare using the same function, one report per horizon.
for loo_aucs, full_aucs, horizon in (
    (loo_10yr_aucs, full_10yr_aucs, "10-YEAR"),
    (loo_30yr_aucs, full_30yr_aucs, "30-YEAR"),
    (loo_static_10yr_aucs, full_static_10yr_aucs, "STATIC 10-YEAR"),
):
    compare_results(loo_aucs, full_aucs, f"LEAVE-ONE-OUT vs FULL POOLED - {horizon} PREDICTIONS")


### Then AWS validation (local runs vs AWS-trained models) for the first ten or so batches ...


from fig5utils import *
import traceback

import pandas as pd
import numpy as np
import torch

# Load full pce_df
pce_df_full = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/pce_prevent_full.csv')
disease_names = essentials['disease_names']

# Storage for results - SEPARATE variables for each analysis type
# Fixed phi from retrospective AWS data
aws_10yr_results = []
aws_30yr_results = []
aws_static_10yr_results = []

# Fixed phi from RETROSPECTIVE data run locally
fixed_retrospective_10yr_results = []
fixed_retrospective_30yr_results = []
fixed_retrospective_static_10yr_results = []

# Load full tensors once (shared across both analyses)
if 'Y_full' not in globals():
    Y_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/Y_tensor.pt')
if 'E_full' not in globals():
    E_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_enrollment_full.pt')


def _evaluate_checkpoint_into(ckpt_path, batch_idx, analysis_type, step_label,
                              model, Y_batch, E_batch, disease_names, pce_df_subset,
                              results_10yr, results_30yr, results_static_10yr,
                              n_bootstraps=100):
    """Load one checkpoint into `model` and evaluate it on one batch of patients.

    Runs, in order, the dynamic 10-year and 30-year evaluations
    (`evaluate_major_diseases_wsex_with_bootstrap_dynamic`) and the static 10-year
    evaluation (`evaluate_major_diseases_wsex_with_bootstrap`). Each result dict is
    tagged with 'batch_idx' and 'analysis_type' and appended to its list right away,
    so if a later horizon fails, the earlier ones are kept.

    Args:
        ckpt_path: checkpoint file containing a 'model_state_dict' entry.
        batch_idx: value stored under 'batch_idx' in every result.
        analysis_type: value stored under 'analysis_type' in every result.
        step_label: prefix of the progress messages (e.g. "LOO").
        model: model whose state and Y/N attributes are overwritten in place.
        Y_batch, E_batch, disease_names, pce_df_subset: passed to the evaluators.
        results_10yr, results_30yr, results_static_10yr: lists that receive the results.
        n_bootstraps: bootstrap replicates per evaluation.

    Raises:
        Whatever torch.load, load_state_dict or the evaluators raise (e.g.
        FileNotFoundError for a missing checkpoint); the caller handles it.
    """
    ckpt = torch.load(ckpt_path, weights_only=False)
    model.load_state_dict(ckpt['model_state_dict'])

    # Point the model at this batch's patients (fresh float32 copy, no autograd history).
    model.Y = torch.as_tensor(Y_batch, dtype=torch.float32).detach().clone()
    model.N = Y_batch.shape[0]

    for years, sink in ((10, results_10yr), (30, results_30yr)):
        print(f"{step_label} - {years} year predictions...")
        result = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
            model, Y_batch, E_batch, disease_names, pce_df_subset,
            n_bootstraps=n_bootstraps, follow_up_duration_years=years, patient_indices=None
        )
        result['batch_idx'] = batch_idx
        result['analysis_type'] = analysis_type
        sink.append(result)

    print(f"{step_label} - Static 10 year predictions...")
    result = evaluate_major_diseases_wsex_with_bootstrap(
        model=model,
        Y_100k=Y_batch,
        E_100k=E_batch,
        disease_names=disease_names,
        pce_df=pce_df_subset,
        n_bootstraps=n_bootstraps,
        follow_up_duration_years=10,
    )
    result['batch_idx'] = batch_idx
    result['analysis_type'] = analysis_type
    results_static_10yr.append(result)


# Batches 0-10 inclusive (11 batches).
# NOTE(review): the AWS folder is named 'aws_first_10_batches_models'; if it only holds
# batches 0-9, batch 10's AWS checkpoint is reported as not found while its local
# RETROSPECTIVE checkpoint is still evaluated. Confirm whether range(10) was intended.
for batch_idx in range(11):
    start_idx = batch_idx * 10000
    end_idx = (batch_idx + 1) * 10000

    print(f"\n{'='*80}")
    print(f"Processing batch {batch_idx}: {start_idx} to {end_idx}")
    print(f"{'='*80}")

    # Get pce_df subset for this batch
    pce_df_subset = pce_df_full[start_idx:end_idx].copy().reset_index(drop=True)

    # Extract batch from full tensors (shared for both analyses)
    Y_batch = Y_full[start_idx:end_idx]
    E_batch = E_full[start_idx:end_idx]

    # ===== FIXED PHI FROM AWS (RETROSPECTIVE) RUNS =====
    aws_ckpt_path = f'/Users/sarahurbut/Downloads/aws_first_10_batches_models/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'

    try:
        print(f"\n--- Fixed Phi (retrospective AWS) ---")
        # NOTE(review): results are tagged 'fixed_enrollment' although these are the
        # retrospective AWS models; the tag is left unchanged in case anything filters on it.
        _evaluate_checkpoint_into(
            aws_ckpt_path, batch_idx, 'fixed_enrollment', "Fixed Phi (retrospective AWS)",
            model, Y_batch, E_batch, disease_names, pce_df_subset,
            aws_10yr_results, aws_30yr_results, aws_static_10yr_results,
        )
    except FileNotFoundError:
        print(f"Fixed phi (retrospective AWS) checkpoint not found: {aws_ckpt_path}")
    except Exception as e:
        # Best-effort: report and continue with the next checkpoint.
        print(f"Error processing fixed phi (retrospective AWS) checkpoint {batch_idx}: {e}")
        traceback.print_exc()

    # ===== FIXED PHI FROM RETROSPECTIVE DATA =====
    fixed_retrospective_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/enrollment_predictions_fixedphi_RETROSPECTIVE_pooled/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'

    try:
        print(f"\n--- Fixed Phi (RETROSPECTIVE) ---")
        _evaluate_checkpoint_into(
            fixed_retrospective_ckpt_path, batch_idx, 'fixed_retrospective', "Fixed Phi (RETROSPECTIVE)",
            model, Y_batch, E_batch, disease_names, pce_df_subset,
            fixed_retrospective_10yr_results, fixed_retrospective_30yr_results,
            fixed_retrospective_static_10yr_results,
        )
    except FileNotFoundError:
        print(f"Fixed phi (RETROSPECTIVE) checkpoint not found: {fixed_retrospective_ckpt_path}")
    except Exception as e:
        print(f"Error processing fixed phi (RETROSPECTIVE) checkpoint {batch_idx}: {e}")
        traceback.print_exc()

print(f"\n{'='*80}")
print("Completed processing all checkpoints!")
print(f"{'='*80}")
# Report the lists this cell fills (the AWS results live in aws_*_results).
print(f"AWS Retrospective - 10yr: {len(aws_10yr_results)} batches")
print(f"AWS Retrospective - 30yr: {len(aws_30yr_results)} batches")
print(f"AWS Retrospective - Static 10yr: {len(aws_static_10yr_results)} batches")
print(f"Fixed Retrospective - 10yr: {len(fixed_retrospective_10yr_results)} batches")
print(f"Fixed Retrospective - 30yr: {len(fixed_retrospective_30yr_results)} batches")
print(f"Fixed Retrospective - Static 10yr: {len(fixed_retrospective_static_10yr_results)} batches")