In [None]:
# Compare Fixed_Retrospective_Pooled vs Cox Baseline (Age + Sex only) on UK Biobank
import pandas as pd
import numpy as np

# Load Cox baseline results (age + sex only, no Aladyn)
cox_df = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox/auc_results_cox_20000_30000train_0_10000test_1121.csv')
# One AUC per disease group: groupby(...).first() keeps the first non-missing value
# when a disease group appears more than once.
cox_baseline = cox_df.groupby('disease_group')['auc'].first().to_dict()

# Load your Fixed_Retrospective_Pooled results
df_10yr = pd.read_csv('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/comparison_all_approaches_10yr.csv', index_col=0)
df_30yr = pd.read_csv('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/comparison_all_approaches_30yr.csv', index_col=0)

# Create comparison: one row per disease present in both the 10-year table and the Cox baseline.
model_auc_10yr = df_10yr['Fixed_Retrospective_Pooled']
model_auc_30yr = df_30yr['Fixed_Retrospective_Pooled']
comparison_columns = [
    'Disease', 'Cox_Baseline_AUC', 'Your_10yr_AUC', 'Your_30yr_AUC',
    'Improvement_10yr', 'Improvement_30yr',
]

comparison_data = []
for disease in df_10yr.index:
    if disease not in cox_baseline:
        continue
    baseline_auc = cox_baseline[disease]
    auc_10yr = model_auc_10yr.loc[disease]
    # The 30-year table is a separate file; a disease missing there becomes NaN
    # instead of aborting the whole comparison with a KeyError.
    auc_30yr = model_auc_30yr.loc[disease] if disease in model_auc_30yr.index else np.nan
    comparison_data.append({
        'Disease': disease,
        'Cox_Baseline_AUC': baseline_auc,
        'Your_10yr_AUC': auc_10yr,
        'Your_30yr_AUC': auc_30yr,
        'Improvement_10yr': auc_10yr - baseline_auc,
        'Improvement_30yr': auc_30yr - baseline_auc,
    })

# Passing `columns` keeps set_index working even when no disease overlaps.
comparison_df = (
    pd.DataFrame(comparison_data, columns=comparison_columns)
    .set_index('Disease')
    .sort_values('Improvement_10yr', ascending=False)
)

# Report the per-horizon comparison tables and summary statistics, then save.
heavy_rule = "=" * 100
light_rule = "-" * 100

print(heavy_rule)
print("COMPARISON: Your Fixed_Retrospective_Pooled vs Cox Baseline (Age + Sex only) on UK Biobank")
print(heavy_rule)

# (section heading, model AUC column, improvement column) for each horizon
horizon_sections = [
    ("\n10-YEAR PREDICTIONS:", 'Your_10yr_AUC', 'Improvement_10yr'),
    ("30-YEAR PREDICTIONS:", 'Your_30yr_AUC', 'Improvement_30yr'),
]
for section_no, (heading, auc_col, gain_col) in enumerate(horizon_sections):
    if section_no > 0:
        # Later sections are separated from the previous one by a rule.
        print("\n" + heavy_rule)
    print(heading)
    print(light_rule)
    print(comparison_df[['Cox_Baseline_AUC', auc_col, gain_col]].round(4))

    gains = comparison_df[gain_col]
    n_diseases = len(comparison_df)
    print(f"\nMean improvement: {gains.mean():.4f}")
    print(f"Median improvement: {gains.median():.4f}")
    print(f"Diseases with improvement >0.05: {(gains > 0.05).sum()} / {n_diseases}")
    print(f"Diseases with improvement >0.10: {(gains > 0.10).sum()} / {n_diseases}")

print("\n" + heavy_rule)
print("TOP IMPROVEMENTS (10-year):")
print(light_rule)
top_10yr = comparison_df.nlargest(10, 'Improvement_10yr')
print(top_10yr[['Cox_Baseline_AUC', 'Your_10yr_AUC', 'Improvement_10yr']].round(4))

# Save comparison (written to the notebook's current working directory)
comparison_df.to_csv('comparison_vs_cox_baseline.csv')
print("\n✓ Saved to comparison_vs_cox_baseline.csv")


In [None]:
import os
import sys
import torch
%load_ext autoreload
%autoreload 2
sys.path.append('/Users/sarahurbut/aladynoulli2/pyScripts/')



In [None]:
# Enrollment-model checkpoint; its 'G' entry is used as the covariate matrix when the model is built below.
# weights_only=False matches the other checkpoint loads in this notebook: PyTorch >= 2.6 defaults to
# weights_only=True, which rejects checkpoints containing non-tensor Python objects.
# Only use this with checkpoint files you trust (it unpickles arbitrary objects).
ckpt = torch.load(
    '/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_retrospective_full/enrollment_model_W0.0001_batch_0_10000.pt',
    weights_only=False,
)


In [None]:
%load_ext autoreload
%autoreload 2

%autoreload 2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.spatial.distance import pdist, squareform
from scipy.special import expit
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt
from sklearn.cluster import SpectralClustering  # Add this import
import rpy2.robjects as robjects
from rpy2.robjects import pandas2ri
pandas2ri.activate()

def load_model_essentials(base_path='/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/'):
    """
    Load the four saved objects the model needs from ``base_path``.

    Parameters
    ----------
    base_path : str
        Directory containing ``Y_tensor.pt``, ``E_matrix.pt``, ``G_matrix.pt`` and
        ``model_essentials.pt``. A trailing path separator is optional.

    Returns
    -------
    tuple
        ``(Y, E, G, essentials)`` exactly as stored in the four files.

    Notes
    -----
    Files are read with ``weights_only=False`` (full unpickling), like the other
    checkpoint loads in this notebook, so only point this at trusted files.
    """
    import os

    print("Loading components...")

    def _load(filename):
        # os.path.join tolerates a base_path with or without a trailing separator.
        return torch.load(os.path.join(base_path, filename), weights_only=False)

    # Load large matrices
    Y = _load('Y_tensor.pt')
    E = _load('E_matrix.pt')
    G = _load('G_matrix.pt')

    # Load other components
    essentials = _load('model_essentials.pt')

    print("Loaded all components successfully!")

    return Y, E, G, essentials

# Load the full tensors and metadata from the default data directory.
Y, E, G, essentials = load_model_essentials()
# NOTE(review): star import; subset_data and the model class used below presumably come from here — confirm.
from clust_huge_amp import *
# Subset the data to indices [0, 10000). The "_100k" suffix is historical; the requested range is 10,000 rows.
Y_100k, E_100k, G_100k, indices = subset_data(Y, E, G, start_index=0, end_index=10000)

# Seed torch / numpy (and CUDA when present) and force deterministic cuDNN before building the model.
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Drop the name of the full Y tensor; only the subset is used from here on in this cell.
del Y

# Load reference trajectories; only the 'signature_refs' entry is used.
refs = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/reference_trajectories.pt')
signature_refs = refs['signature_refs']

# Read the PCE/PREVENT table from an R .rds file through rpy2.
readRDS = robjects.r['readRDS']
pce_data = readRDS('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/pce_df_prevent.rds')
pce_df = pandas2ri.rpy2py(pce_data)  # Convert to pandas DataFrame
sex=pce_df['Sex'].values

# Convert to numeric: Female=0, Male=1
# (.astype(int) will raise if 'Sex' holds any value other than these two labels, since map() yields NaN.)

pce_df['sex_numeric'] = pce_df['Sex'].map({'Female': 0, 'Male': 1}).astype(int)

# NOTE(review): `sex` is not used again in this cell; the covariate matrix comes from the
# checkpoint's 'G' entry instead (presumably already including sex — confirm).
sex=pce_df['sex_numeric'].values
G_with_sex = ckpt['G']


# Build the model on the subset; P is taken from the checkpoint's covariate matrix width.
model = AladynSurvivalFixedKernelsAvgLoss_clust_logitInit_psitest(
    N=Y_100k.shape[0],
    D=Y_100k.shape[1],
    T=Y_100k.shape[2],
    K=20,
    P=G_with_sex.shape[1],
    init_sd_scaler=1e-1,
    G=G_with_sex,
    Y=Y_100k,
    genetic_scale=1,
    W=0,
    R=0,
    prevalence_t=essentials['prevalence_t'],
    signature_references=signature_refs,  # Only pass signature refs
    healthy_reference=True,  # NOTE(review): passes True; an earlier comment here said "None" — confirm intended value
    disease_names=essentials['disease_names']
)

# Re-seed before parameter initialisation.
torch.manual_seed(0)
np.random.seed(0)


# NOTE(review): profiling imports are not used anywhere in the visible part of this notebook.
import cProfile
import pstats
from pstats import SortKey

# Load the saved initial psi and cluster assignments and install them on the model.
initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/initial_psi_400k.pt')
initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/initial_clusters_400k.pt')

model.initialize_params(true_psi=initial_psi)
model.clusters = initial_clusters
# NOTE(review): model.clusters was assigned from initial_clusters on the line above, so this
# comparison only confirms the assignment itself, not agreement with an independent run.
clusters_match = np.array_equal(initial_clusters, model.clusters)
print(f"\nClusters match exactly: {clusters_match}")


## Comparing AWS versus local results

In [None]:
# Full PCE/PREVENT table; the evaluation loops below slice it by row position per batch.
pce_df_full = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/pce_prevent_full.csv')


In [None]:
# Display the `age` column of the full PCE/PREVENT table (cell output only; nothing is stored).
pce_df_full['age']

In [None]:
# Load the baseline covariates file and display its `age` column,
# for visual comparison with pce_df_full['age'] in the previous cell.
covariates_path = '/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/baselinagefamh_withpcs.csv'
fh_processed = pd.read_csv(covariates_path)
fh_processed['age']

In [None]:
from fig5utils import *
import pandas as pd
import numpy as np
import torch

# Load full pce_df
pce_df_full = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/pce_prevent_full.csv')
disease_names = essentials['disease_names']

# Storage for results - SEPARATE variables for each analysis type
# Fixed phi checkpoints trained on AWS
aws_10yr_results = []
aws_30yr_results = []
aws_static_10yr_results = []

# Fixed phi from RETROSPECTIVE data run locally
fixed_retrospective_10yr_results = []
fixed_retrospective_30yr_results = []
fixed_retrospective_static_10yr_results = []

# Load full tensors once (shared across both analyses); skipped when a previous run left them in the kernel.
if 'Y_full' not in globals():
    Y_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/Y_tensor.pt')
if 'E_full' not in globals():
    E_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_enrollment_full.pt')


def _evaluate_fixed_phi_checkpoint(ckpt_path, batch_idx, Y_batch, E_batch, pce_df_subset,
                                   run_label, source_label, analysis_type,
                                   results_10yr, results_30yr, results_static_10yr):
    """Load one fixed-phi checkpoint into the global ``model`` and evaluate it on a batch.

    Runs the dynamic 10-year and 30-year evaluations and the static 10-year evaluation,
    tags each result with ``batch_idx`` and ``analysis_type``, and appends it to the
    matching list as soon as it is available, so results computed before a failure are kept.

    Parameters
    ----------
    ckpt_path : str
        Checkpoint file; must contain a 'model_state_dict' entry.
    batch_idx : int
        Batch number recorded on every result.
    Y_batch, E_batch
        Slices of the full Y / E tensors for this batch.
    pce_df_subset : pandas.DataFrame
        Rows of the PCE table for the same patients.
    run_label : str
        Text used in progress messages.
    source_label : str
        Text used inside the "not found" / error messages.
    analysis_type : str
        Tag stored on each result.
    results_10yr, results_30yr, results_static_10yr : list
        Lists that receive the results (mutated in place).

    A missing checkpoint or any other exception is reported and swallowed so the
    remaining batches still run. Uses the notebook-level ``model`` and ``disease_names``.
    """
    try:
        print(f"\n--- {run_label} ---")
        checkpoint = torch.load(ckpt_path, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])

        # Update model.Y and model.N so forward() uses correct patients
        model.Y = torch.tensor(Y_batch, dtype=torch.float32)
        model.N = Y_batch.shape[0]

        # Dynamic predictions at each follow-up horizon
        for horizon_years, sink in ((10, results_10yr), (30, results_30yr)):
            print(f"{run_label} - {horizon_years} year predictions...")
            dynamic_result = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
                model, Y_batch, E_batch, disease_names, pce_df_subset,
                n_bootstraps=100, follow_up_duration_years=horizon_years, patient_indices=None
            )
            dynamic_result['batch_idx'] = batch_idx
            dynamic_result['analysis_type'] = analysis_type
            sink.append(dynamic_result)

        # Static 10-year predictions (using 1-year score)
        print(f"{run_label} - Static 10 year predictions...")
        static_result = evaluate_major_diseases_wsex_with_bootstrap(
            model=model,
            Y_100k=Y_batch,
            E_100k=E_batch,
            disease_names=disease_names,
            pce_df=pce_df_subset,
            n_bootstraps=100,
            follow_up_duration_years=10,
        )
        static_result['batch_idx'] = batch_idx
        static_result['analysis_type'] = analysis_type
        results_static_10yr.append(static_result)

    except FileNotFoundError:
        print(f"Fixed phi ({source_label}) checkpoint not found: {ckpt_path}")
    except Exception as e:
        print(f"Error processing fixed phi ({source_label}) checkpoint {batch_idx}: {e}")
        import traceback
        traceback.print_exc()


# Loop through batches 0-10 inclusive (11 batches of 10,000 rows); a batch whose
# checkpoint file is absent is reported and skipped by the helper.
for batch_idx in range(11):
    start_idx = batch_idx * 10000
    end_idx = (batch_idx + 1) * 10000

    print(f"\n{'='*80}")
    print(f"Processing batch {batch_idx}: {start_idx} to {end_idx}")
    print(f"{'='*80}")

    # Get pce_df subset for this batch
    pce_df_subset = pce_df_full[start_idx:end_idx].copy().reset_index(drop=True)

    # Extract batch from full tensors (shared for both analyses)
    Y_batch = Y_full[start_idx:end_idx]
    E_batch = E_full[start_idx:end_idx]

    # ===== FIXED PHI FROM AWS POOLED DATA =====
    fixed_enrollment_ckpt_path = f'/Users/sarahurbut/Downloads/aws_first_10_batches_models/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'
    _evaluate_fixed_phi_checkpoint(
        fixed_enrollment_ckpt_path, batch_idx, Y_batch, E_batch, pce_df_subset,
        run_label="Fixed Phi (retrospective AWS)",
        source_label="ENROLLMENT",
        analysis_type='fixed_enrollment',
        results_10yr=aws_10yr_results,
        results_30yr=aws_30yr_results,
        results_static_10yr=aws_static_10yr_results,
    )

    # ===== FIXED PHI FROM RETROSPECTIVE DATA =====
    fixed_retrospective_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/enrollment_predictions_fixedphi_RETROSPECTIVE_pooled/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'
    _evaluate_fixed_phi_checkpoint(
        fixed_retrospective_ckpt_path, batch_idx, Y_batch, E_batch, pce_df_subset,
        run_label="Fixed Phi (RETROSPECTIVE)",
        source_label="RETROSPECTIVE",
        analysis_type='fixed_retrospective',
        results_10yr=fixed_retrospective_10yr_results,
        results_30yr=fixed_retrospective_30yr_results,
        results_static_10yr=fixed_retrospective_static_10yr_results,
    )

# Completion summary: how many batches produced a result for each analysis.
print(f"\n{'='*80}")
print("Completed processing all checkpoints!")
print(f"{'='*80}")
# The 'fixed_enrollment' runs are collected in aws_*_results; the previously referenced
# fixed_enrollment_*_results names were never defined and raised NameError.
print(f"Fixed Enrollment - 10yr: {len(aws_10yr_results)} batches")
print(f"Fixed Enrollment - 30yr: {len(aws_30yr_results)} batches")
print(f"Fixed Retrospective - 10yr: {len(fixed_retrospective_10yr_results)} batches")
print(f"Fixed Retrospective - 30yr: {len(fixed_retrospective_30yr_results)} batches")


In [None]:
# Compare AWS vs Local results per batch
import pandas as pd
import numpy as np

def extract_aucs_from_results(results_list):
    """Collect the 'auc' value of every disease entry, grouped by batch.

    Each element of ``results_list`` is a mapping with a 'batch_idx' key plus one
    entry per disease. Only entries whose value is a dict containing 'auc' are kept;
    'batch_idx' and 'analysis_type' are bookkeeping keys and are skipped.

    Returns ``{batch_idx: {disease: auc}}``. When several results share a batch index,
    later ones overwrite earlier ones for the same disease.
    """
    bookkeeping_keys = ('batch_idx', 'analysis_type')
    collected = {}
    for entry in results_list:
        per_disease = collected.setdefault(entry['batch_idx'], {})
        per_disease.update({
            name: payload['auc']
            for name, payload in entry.items()
            if name not in bookkeeping_keys
            and isinstance(payload, dict)
            and 'auc' in payload
        })
    return collected

# Extract AUCs for all result types as {batch_idx: {disease: auc}}.
# "aws" = results tagged 'fixed_enrollment' (checkpoints from the AWS download folder).
aws_10yr_aucs = extract_aucs_from_results(aws_10yr_results)
aws_30yr_aucs = extract_aucs_from_results(aws_30yr_results)
aws_static_10yr_aucs = extract_aucs_from_results(aws_static_10yr_results)

# "local" = results tagged 'fixed_retrospective' (checkpoints from the local retrospective-pooled folder).
local_10yr_aucs = extract_aucs_from_results(fixed_retrospective_10yr_results)
local_30yr_aucs = extract_aucs_from_results(fixed_retrospective_30yr_results)
local_static_10yr_aucs = extract_aucs_from_results(fixed_retrospective_static_10yr_results)

def compare_results(aws_aucs, local_aucs, title):
    """Print a per-batch and overall comparison of two AUC collections.

    Both inputs are ``{batch_idx: {disease: auc}}`` mappings. For every batch present
    in either input, the diseases common to both are listed with their absolute AUC
    difference and a marker (✓ below 0.01, ⚠ below 0.05, ✗ otherwise), followed by a
    batch summary. An overall summary across all batches is printed at the end when
    at least one comparison was made. Returns None.
    """
    def _count_below(values, cutoff):
        return sum(1 for value in values if value < cutoff)

    wide_rule = '=' * 100
    rule = '=' * 80

    print(f"\n{wide_rule}")
    print(f"{title}")
    print(f"{wide_rule}")

    pooled_gaps = []

    for batch_idx in sorted(set(aws_aucs) | set(local_aucs)):
        print(f"\n{rule}")
        print(f"BATCH {batch_idx}")
        print(f"{rule}")

        aws_batch = aws_aucs.get(batch_idx, {})
        local_batch = local_aucs.get(batch_idx, {})
        shared = set(aws_batch) & set(local_batch)

        if len(shared) == 0:
            print("No common diseases found")
            continue

        print(f"\n{'Disease':<30} {'AWS':<12} {'Local':<12} {'Difference':<12} {'Match':<8}")
        print("-" * 80)

        batch_gaps = []
        for disease in sorted(shared):
            aws_auc = aws_batch[disease]
            local_auc = local_batch[disease]
            gap = abs(aws_auc - local_auc)
            batch_gaps.append(gap)
            if gap < 0.01:
                flag = "✓"
            elif gap < 0.05:
                flag = "⚠"
            else:
                flag = "✗"
            print(f"{disease:<30} {aws_auc:<12.4f} {local_auc:<12.4f} {gap:<12.4f} {flag:<8}")
        pooled_gaps.extend(batch_gaps)

        n_batch = len(batch_gaps)
        print(f"\nBatch {batch_idx} Summary:")
        print(f"  Mean difference: {sum(batch_gaps)/n_batch:.4f}")
        print(f"  Max difference: {max(batch_gaps):.4f}")
        print(f"  Min difference: {min(batch_gaps):.4f}")
        print(f"  Diseases with diff < 0.01: {_count_below(batch_gaps, 0.01)}/{n_batch}")
        print(f"  Diseases with diff < 0.05: {_count_below(batch_gaps, 0.05)}/{n_batch}")

    if not pooled_gaps:
        return

    n_total = len(pooled_gaps)
    n_tight = _count_below(pooled_gaps, 0.01)
    n_loose = _count_below(pooled_gaps, 0.05)
    print(f"\n{rule}")
    print(f"OVERALL SUMMARY ({title})")
    print(f"{rule}")
    print(f"Mean difference: {sum(pooled_gaps)/n_total:.4f}")
    print(f"Max difference: {max(pooled_gaps):.4f}")
    print(f"Min difference: {min(pooled_gaps):.4f}")
    print(f"Median difference: {np.median(pooled_gaps):.4f}")
    print(f"Std difference: {np.std(pooled_gaps):.4f}")
    print(f"Total comparisons: {n_total}")
    print(f"Diseases with diff < 0.01: {n_tight}/{n_total} ({100*n_tight/n_total:.1f}%)")
    print(f"Diseases with diff < 0.05: {n_loose}/{n_total} ({100*n_loose/n_total:.1f}%)")

# Print the AWS-vs-local report once per prediction type.
# Compare 10-year predictions
compare_results(aws_10yr_aucs, local_10yr_aucs, "AWS vs LOCAL - 10-YEAR PREDICTIONS")

# Compare 30-year predictions
compare_results(aws_30yr_aucs, local_30yr_aucs, "AWS vs LOCAL - 30-YEAR PREDICTIONS")

# Compare static 10-year predictions
compare_results(aws_static_10yr_aucs, local_static_10yr_aucs, "AWS vs LOCAL - STATIC 10-YEAR PREDICTIONS")

In [None]:
# ============================================================================
# SAVE AWS vs LOCAL COMPARISON RESULTS TO CSV AND CREATE VISUALIZATIONS
# ============================================================================

# Extract all differences from AWS vs Local comparisons
def extract_aws_local_differences(aws_aucs, local_aucs, prediction_type):
    """Build one record per (batch, disease) present in both AUC collections.

    Parameters
    ----------
    aws_aucs, local_aucs : dict
        ``{batch_idx: {disease: auc}}`` mappings.
    prediction_type : str
        Label copied onto every record.

    Returns
    -------
    list of dict
        Records with keys 'batch_idx', 'disease', 'aws_auc', 'local_auc',
        'difference' (absolute AUC difference) and 'prediction_type', ordered by
        batch index and then by disease name.
    """
    all_differences = []

    # Only batches evaluated on both sides can be compared.
    for batch_idx in sorted(set(aws_aucs.keys()) & set(local_aucs.keys())):
        aws_batch = aws_aucs[batch_idx]
        local_batch = local_aucs[batch_idx]

        # Sort the shared diseases: iterating the raw set gives an order that depends on
        # string hashing, so the saved CSV rows would change order between kernel sessions.
        for disease in sorted(set(aws_batch.keys()) & set(local_batch.keys())):
            aws_auc = aws_batch[disease]
            local_auc = local_batch[disease]

            all_differences.append({
                'batch_idx': batch_idx,
                'disease': disease,
                'aws_auc': aws_auc,
                'local_auc': local_auc,
                'difference': abs(aws_auc - local_auc),
                'prediction_type': prediction_type
            })

    return all_differences

# Gather the per-disease differences for every prediction type into one flat list of records.
all_aws_local_diffs = [
    record
    for aws_set, local_set, type_label in (
        (aws_10yr_aucs, local_10yr_aucs, '10-Year'),
        (aws_30yr_aucs, local_30yr_aucs, '30-Year'),
        (aws_static_10yr_aucs, local_static_10yr_aucs, 'Static 10-Year'),
    )
    for record in extract_aws_local_differences(aws_set, local_set, type_label)
]

# One row per (batch, disease, prediction type)
df_aws_local = pd.DataFrame(all_aws_local_diffs)

# Persist the table, then report what was written.
csv_path_aws_local = '/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/aws_vs_local_comparison_results.csv'
df_aws_local.to_csv(csv_path_aws_local, index=False)
print(f"✓ Saved AWS vs Local comparison results to: {csv_path_aws_local}")
print(f"  Total comparisons: {len(df_aws_local)}")
print(f"  Batches: {sorted(df_aws_local['batch_idx'].unique())}")
print(f"  Prediction types: {df_aws_local['prediction_type'].unique()}")

# ============================================================================
# CREATE VISUALIZATIONS FOR AWS vs LOCAL
# ============================================================================

# seaborn is used for the box plot and the heatmap in this figure. It was previously only
# imported by a later cell, so running the notebook top to bottom raised NameError here.
import seaborn as sns

# 2 x 3 grid of panels on a single figure.
fig2 = plt.figure(figsize=(16, 10))

# 1. Distribution of differences by prediction type (differences scaled by 1000 for readability)
ax1 = plt.subplot(2, 3, 1)
for pred_type in df_aws_local['prediction_type'].unique():
    diffs = df_aws_local[df_aws_local['prediction_type'] == pred_type]['difference'] * 1000
    ax1.hist(diffs, bins=30, alpha=0.6, label=pred_type, edgecolor='black', linewidth=0.5)

ax1.set_xlabel('AUC Difference (×1000)', fontsize=11, fontweight='bold')
ax1.set_ylabel('Frequency', fontsize=11, fontweight='bold')
ax1.set_title('Distribution of Differences\nby Prediction Type', fontsize=12, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(alpha=0.3, linestyle='--')

# 2. Box plot of differences by prediction type
ax2 = plt.subplot(2, 3, 2)
# Adds a scaled helper column to df_aws_local (after the CSV above has already been written).
df_aws_local['difference_x1000'] = df_aws_local['difference'] * 1000
sns.boxplot(data=df_aws_local, x='prediction_type', y='difference_x1000', ax=ax2, palette='Set1')
ax2.set_xlabel('Prediction Type', fontsize=11, fontweight='bold')
ax2.set_ylabel('AUC Difference (×1000)', fontsize=11, fontweight='bold')
ax2.set_title('Distribution of Differences\n(Box Plot)', fontsize=12, fontweight='bold')
ax2.grid(axis='y', alpha=0.3, linestyle='--')

# 3. Max difference per batch: one line per prediction type, largest |AUC difference| in each batch
ax3 = plt.subplot(2, 3, 3)
batch_max_diffs = df_aws_local.groupby(['batch_idx', 'prediction_type'])['difference'].max().reset_index()
for pred_type in df_aws_local['prediction_type'].unique():
    batch_data = batch_max_diffs[batch_max_diffs['prediction_type'] == pred_type]
    ax3.plot(batch_data['batch_idx'], batch_data['difference'] * 1000,
             marker='o', linewidth=2, markersize=8, label=pred_type)

ax3.set_xlabel('Batch Index', fontsize=11, fontweight='bold')
ax3.set_ylabel('Max AUC Difference (×1000)', fontsize=11, fontweight='bold')
ax3.set_title('Max Difference per Batch', fontsize=12, fontweight='bold')
ax3.legend(fontsize=10)
ax3.grid(alpha=0.3, linestyle='--')
# One tick per batch that actually has comparisons.
ax3.set_xticks(sorted(df_aws_local['batch_idx'].unique()))

# 4. Summary statistics bar chart: mean / max / std of the difference, grouped by prediction type
ax4 = plt.subplot(2, 3, 4)
summary_stats = df_aws_local.groupby('prediction_type')['difference'].agg(['mean', 'max', 'std']).reset_index()
x = np.arange(len(summary_stats))
width = 0.25

# Three bars side by side per prediction type, offset by one bar width.
bars1 = ax4.bar(x - width, summary_stats['mean'] * 1000, width, label='Mean', color='#2E86AB', alpha=0.8)
bars2 = ax4.bar(x, summary_stats['max'] * 1000, width, label='Max', color='#A23B72', alpha=0.8)
bars3 = ax4.bar(x + width, summary_stats['std'] * 1000, width, label='Std', color='#F18F01', alpha=0.8)

ax4.set_ylabel('AUC Difference (×1000)', fontsize=11, fontweight='bold')
ax4.set_title('Summary Statistics', fontsize=12, fontweight='bold')
ax4.set_xticks(x)
ax4.set_xticklabels(summary_stats['prediction_type'], rotation=15, ha='right')
ax4.legend(fontsize=10)
ax4.grid(axis='y', alpha=0.3, linestyle='--')

# Add value labels above each bar; bars of height 0 (or NaN) are left unlabelled.
for bars in [bars1, bars2, bars3]:
    for bar in bars:
        height = bar.get_height()
        if height > 0:
            ax4.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.2f}',
                    ha='center', va='bottom', fontsize=9)

# 5. Percentage within thresholds: share of comparisons whose |AUC difference| is below each cutoff
ax5 = plt.subplot(2, 3, 5)
thresholds = [0.001, 0.005, 0.01, 0.05]
# Column labels on the ×1000 scale, kept in ascending numeric order.
threshold_labels = [f'{thresh*1000:.1f}' for thresh in thresholds]
threshold_data = []
for pred_type in df_aws_local['prediction_type'].unique():
    pred_diffs = df_aws_local[df_aws_local['prediction_type'] == pred_type]['difference']
    total = len(pred_diffs)
    for thresh, thresh_label in zip(thresholds, threshold_labels):
        within = (pred_diffs < thresh).sum()
        threshold_data.append({
            'prediction_type': pred_type,
            'threshold': thresh_label,
            'percentage': within / total * 100
        })

df_thresh = pd.DataFrame(threshold_data)
pivot_thresh = df_thresh.pivot(index='prediction_type', columns='threshold', values='percentage')
# pivot() sorts the string labels lexicographically ('1.0', '10.0', '5.0', '50.0');
# put the columns back in increasing-threshold order so the heatmap reads left to right.
pivot_thresh = pivot_thresh[threshold_labels]

sns.heatmap(pivot_thresh, annot=True, fmt='.1f', cmap='YlOrRd', ax=ax5,
            cbar_kws={'label': '% Within Threshold'}, vmin=0, vmax=100)
ax5.set_xlabel('Threshold (×1000)', fontsize=11, fontweight='bold')
ax5.set_ylabel('Prediction Type', fontsize=11, fontweight='bold')
ax5.set_title('Percentage Within Thresholds', fontsize=12, fontweight='bold')

# 6. Scatter plot: AWS vs Local AUCs, one point per (batch, disease), coloured by prediction type
ax6 = plt.subplot(2, 3, 6)
for pred_type in df_aws_local['prediction_type'].unique():
    pred_data = df_aws_local[df_aws_local['prediction_type'] == pred_type]
    ax6.scatter(pred_data['local_auc'], pred_data['aws_auc'],
               alpha=0.5, s=30, label=pred_type)

# Add diagonal line (y = x) spanning the full observed AUC range; points on it agree exactly.
min_val = min(df_aws_local['aws_auc'].min(), df_aws_local['local_auc'].min())
max_val = max(df_aws_local['aws_auc'].max(), df_aws_local['local_auc'].max())
ax6.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2, alpha=0.7, label='y=x')

ax6.set_xlabel('Local AUC', fontsize=11, fontweight='bold')
ax6.set_ylabel('AWS AUC', fontsize=11, fontweight='bold')
ax6.set_title('AWS vs Local AUCs\n(Perfect match = diagonal)', fontsize=12, fontweight='bold')
ax6.legend(fontsize=9, loc='lower right')
ax6.grid(alpha=0.3, linestyle='--')

plt.suptitle('AWS vs Local Comparison: Platform Consistency',
             fontsize=14, fontweight='bold', y=0.98)

# Leave the top 3% of the figure free for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.97])

# Save figure (must happen before plt.show(), which may clear the current figure)
fig_path_aws_local = '/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/aws_vs_local_comparison_plot.png'
plt.savefig(fig_path_aws_local, dpi=300, bbox_inches='tight')
print(f"✓ Saved figure to: {fig_path_aws_local}")

plt.show()

# Text summary of the AWS-vs-local differences, one section per prediction type
# (sections appear in order of first appearance in df_aws_local).
summary_rule = "=" * 80
print("\n" + summary_rule)
print("AWS vs LOCAL SUMMARY STATISTICS:")
print(summary_rule)
for pred_type, pred_data in df_aws_local.groupby('prediction_type', sort=False):
    gaps = pred_data['difference']
    n_rows = len(pred_data)
    print(f"\n{pred_type}:")
    # Location / spread statistics, reported on the ×1000 scale
    for stat_label, stat_value in (
        ("Mean", gaps.mean()),
        ("Max", gaps.max()),
        ("Std", gaps.std()),
        ("Median", gaps.median()),
    ):
        print(f"  {stat_label} difference: {stat_value*1000:.3f} (×1000)")
    # Share of comparisons below each cutoff
    for cutoff in (0.001, 0.01):
        n_within = (gaps < cutoff).sum()
        print(f"  Comparisons < {cutoff}: {n_within}/{n_rows} ({n_within/n_rows*100:.1f}%)")
print("\n" + summary_rule)

### Leave-one-out validation



In [None]:
# Compare Leave-One-Out vs Full Pooled results
# For batches that were excluded in LOO validation.
# Relies on `model`, `pce_df_full`, `disease_names` and the evaluate_* functions from earlier cells.

# Batches excluded in LOO (from the folder list)
excluded_batches = [0, 6, 15, 17, 18, 20, 24, 34, 35, 37]

# Storage for results: one list per (checkpoint source, prediction type)
loo_10yr_results = []
loo_30yr_results = []
loo_static_10yr_results = []

full_pooled_10yr_results = []
full_pooled_30yr_results = []
full_pooled_static_10yr_results = []

# Load full tensors once; skipped when an earlier cell already left them in the kernel
if 'Y_full' not in globals():
    Y_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/Y_tensor.pt')
if 'E_full' not in globals():
    E_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_enrollment_full.pt')

# Loop through excluded batches; each batch covers rows [batch_idx*10000, (batch_idx+1)*10000)
for batch_idx in excluded_batches:
    start_idx = batch_idx * 10000
    end_idx = (batch_idx + 1) * 10000

    print(f"\n{'='*80}")
    print(f"Processing batch {batch_idx}: {start_idx} to {end_idx}")
    print(f"{'='*80}")

    # Get pce_df subset for this batch (positional slice, index reset to 0..n-1)
    pce_df_subset = pce_df_full[start_idx:end_idx].copy().reset_index(drop=True)

    # Extract batch from full tensors
    Y_batch = Y_full[start_idx:end_idx]
    E_batch = E_full[start_idx:end_idx]

    # ===== LEAVE-ONE-OUT RESULTS =====
    loo_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/leave_one_out_validation/batch_{batch_idx}/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'

    try:
        print(f"\n--- Leave-One-Out (excluded batch {batch_idx}) ---")
        loo_ckpt = torch.load(loo_ckpt_path, weights_only=False)
        model.load_state_dict(loo_ckpt['model_state_dict'])

        # Point the shared model at this batch's patients before evaluating
        model.Y = torch.tensor(Y_batch, dtype=torch.float32)
        model.N = Y_batch.shape[0]

        # 10-year predictions
        print(f"LOO - 10 year predictions...")
        loo_10yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
            model, Y_batch, E_batch, disease_names, pce_df_subset,
            n_bootstraps=100, follow_up_duration_years=10, patient_indices=None
        )
        loo_10yr['batch_idx'] = batch_idx
        loo_10yr['analysis_type'] = 'leave_one_out'
        loo_10yr_results.append(loo_10yr)

        # 30-year predictions
        print(f"LOO - 30 year predictions...")
        loo_30yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
            model, Y_batch, E_batch, disease_names, pce_df_subset,
            n_bootstraps=100, follow_up_duration_years=30, patient_indices=None
        )
        loo_30yr['batch_idx'] = batch_idx
        loo_30yr['analysis_type'] = 'leave_one_out'
        loo_30yr_results.append(loo_30yr)

        # Static 10-year predictions
        print(f"LOO - Static 10 year predictions...")
        loo_static_10yr = evaluate_major_diseases_wsex_with_bootstrap(
            model=model,
            Y_100k=Y_batch,
            E_100k=E_batch,
            disease_names=disease_names,
            pce_df=pce_df_subset,
            n_bootstraps=100,
            follow_up_duration_years=10,
        )
        loo_static_10yr['batch_idx'] = batch_idx
        loo_static_10yr['analysis_type'] = 'leave_one_out'
        loo_static_10yr_results.append(loo_static_10yr)

    except FileNotFoundError:
        # A missing checkpoint is reported and the batch is skipped for this source
        print(f"LOO checkpoint not found: {loo_ckpt_path}")
    except Exception as e:
        # Any other failure is logged with a traceback; results appended before it are kept
        print(f"Error processing LOO checkpoint {batch_idx}: {e}")
        import traceback
        traceback.print_exc()

    # ===== FULL POOLED RESULTS =====
    full_pooled_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/enrollment_predictions_fixedphi_RETROSPECTIVE_pooled/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'

    try:
        print(f"\n--- Full Pooled (all 40 batches) ---")
        full_ckpt = torch.load(full_pooled_ckpt_path, weights_only=False)
        model.load_state_dict(full_ckpt['model_state_dict'])

        # Same model object as above, now carrying the full-pooled weights
        model.Y = torch.tensor(Y_batch, dtype=torch.float32)
        model.N = Y_batch.shape[0]

        # 10-year predictions
        print(f"Full Pooled - 10 year predictions...")
        full_10yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
            model, Y_batch, E_batch, disease_names, pce_df_subset,
            n_bootstraps=100, follow_up_duration_years=10, patient_indices=None
        )
        full_10yr['batch_idx'] = batch_idx
        full_10yr['analysis_type'] = 'full_pooled'
        full_pooled_10yr_results.append(full_10yr)

        # 30-year predictions
        print(f"Full Pooled - 30 year predictions...")
        full_30yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
            model, Y_batch, E_batch, disease_names, pce_df_subset,
            n_bootstraps=100, follow_up_duration_years=30, patient_indices=None
        )
        full_30yr['batch_idx'] = batch_idx
        full_30yr['analysis_type'] = 'full_pooled'
        full_pooled_30yr_results.append(full_30yr)

        # Static 10-year predictions
        print(f"Full Pooled - Static 10 year predictions...")
        full_static_10yr = evaluate_major_diseases_wsex_with_bootstrap(
            model=model,
            Y_100k=Y_batch,
            E_100k=E_batch,
            disease_names=disease_names,
            pce_df=pce_df_subset,
            n_bootstraps=100,
            follow_up_duration_years=10,
        )
        full_static_10yr['batch_idx'] = batch_idx
        full_static_10yr['analysis_type'] = 'full_pooled'
        full_pooled_static_10yr_results.append(full_static_10yr)

    except FileNotFoundError:
        print(f"Full pooled checkpoint not found: {full_pooled_ckpt_path}")
    except Exception as e:
        print(f"Error processing full pooled checkpoint {batch_idx}: {e}")
        import traceback
        traceback.print_exc()

# Report how many batches produced results for each source / horizon
print(f"\n{'='*80}")
print("Completed processing!")
print(f"{'='*80}")
print(f"LOO - 10yr: {len(loo_10yr_results)} batches")
print(f"LOO - 30yr: {len(loo_30yr_results)} batches")
print(f"Full Pooled - 10yr: {len(full_pooled_10yr_results)} batches")
print(f"Full Pooled - 30yr: {len(full_pooled_30yr_results)} batches")

# Extract AUCs and compare
def extract_aucs_from_results(results_list):
    """Collect the 'auc' value of every disease, grouped by batch index.

    Args:
        results_list: list of per-batch result dicts. Each dict must contain a
            'batch_idx' entry; every other entry whose value is a dict holding
            an 'auc' key is treated as a disease result.

    Returns:
        dict mapping batch_idx -> {disease: auc}. Results that share a
        batch_idx are merged, with later entries overwriting earlier ones.
    """
    metadata_keys = ('batch_idx', 'analysis_type')
    aucs_by_batch = {}
    for batch_result in results_list:
        per_disease = aucs_by_batch.setdefault(batch_result['batch_idx'], {})
        per_disease.update({
            name: entry['auc']
            for name, entry in batch_result.items()
            if name not in metadata_keys and isinstance(entry, dict) and 'auc' in entry
        })
    return aucs_by_batch

# Per-batch AUC lookups: leave-one-out runs first, then the full-pooled runs.
loo_10yr_aucs, loo_30yr_aucs, loo_static_10yr_aucs = map(
    extract_aucs_from_results,
    (loo_10yr_results, loo_30yr_results, loo_static_10yr_results),
)
full_10yr_aucs, full_30yr_aucs, full_static_10yr_aucs = map(
    extract_aucs_from_results,
    (full_pooled_10yr_results, full_pooled_30yr_results, full_pooled_static_10yr_results),
)

# Run the shared comparison once per prediction type, in the same order as before.
for _loo_lookup, _full_lookup, _heading in [
    (loo_10yr_aucs, full_10yr_aucs, "LEAVE-ONE-OUT vs FULL POOLED - 10-YEAR PREDICTIONS"),
    (loo_30yr_aucs, full_30yr_aucs, "LEAVE-ONE-OUT vs FULL POOLED - 30-YEAR PREDICTIONS"),
    (loo_static_10yr_aucs, full_static_10yr_aucs, "LEAVE-ONE-OUT vs FULL POOLED - STATIC 10-YEAR PREDICTIONS"),
]:
    compare_results(_loo_lookup, _full_lookup, _heading)

In [None]:
# ============================================================================
# SAVE LOO VALIDATION RESULTS TO CSV AND CREATE VISUALIZATIONS
# ============================================================================

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Extract all differences from the comparison results
def extract_differences(loo_aucs, full_aucs, prediction_type):
    """Build one record per (batch, disease) pair present in both AUC lookups.

    Args:
        loo_aucs: dict mapping batch_idx -> {disease: auc} for the leave-one-out runs.
        full_aucs: dict with the same layout for the full-pooled runs.
        prediction_type: label stored verbatim in each record's 'prediction_type' field.

    Returns:
        list of dicts with keys 'batch_idx', 'disease', 'loo_auc',
        'full_pooled_auc', 'difference' (absolute AUC difference) and
        'prediction_type'. Records are ordered by batch index and then by
        disease name, so the output is identical from run to run.
    """
    all_differences = []

    # Only batches evaluated under both approaches can be compared.
    for batch_idx in sorted(set(loo_aucs) & set(full_aucs)):
        loo_batch = loo_aucs[batch_idx]
        full_batch = full_aucs[batch_idx]

        # Iterating a raw set of strings yields a different order on every kernel
        # (hash randomisation); sort so the saved CSV is reproducible.
        common_diseases = sorted(set(loo_batch) & set(full_batch), key=str)

        for disease in common_diseases:
            loo_auc = loo_batch[disease]
            full_auc = full_batch[disease]

            all_differences.append({
                'batch_idx': batch_idx,
                'disease': disease,
                'loo_auc': loo_auc,
                'full_pooled_auc': full_auc,
                'difference': abs(loo_auc - full_auc),
                'prediction_type': prediction_type
            })

    return all_differences

# Gather the LOO-vs-full-pooled records for every prediction type, in a fixed order.
all_diffs = [
    record
    for loo_lookup, full_lookup, type_label in (
        (loo_10yr_aucs, full_10yr_aucs, '10-Year'),
        (loo_30yr_aucs, full_30yr_aucs, '30-Year'),
        (loo_static_10yr_aucs, full_static_10yr_aucs, 'Static 10-Year'),
    )
    for record in extract_differences(loo_lookup, full_lookup, type_label)
]

# One row per (batch, disease, prediction type)
df_loo = pd.DataFrame(all_diffs)

# Persist the table, then report what was written
csv_path = '/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/loo_validation_results.csv'
df_loo.to_csv(csv_path, index=False)
print(f"✓ Saved LOO validation results to: {csv_path}")
print(f"  Total comparisons: {len(df_loo)}")
print(f"  Batches: {sorted(df_loo['batch_idx'].unique())}")
print(f"  Prediction types: {df_loo['prediction_type'].unique()}")

# ============================================================================
# CREATE VISUALIZATIONS
# ============================================================================

# One 16x10 figure holding a 2x3 grid of panels; each panel below is created with
# plt.subplot(2, 3, k) and then drawn through its own Axes object.
fig = plt.figure(figsize=(16, 10))

# 1. Distribution of differences by prediction type
# Overlaid, semi-transparent histograms (one per prediction type) of the absolute
# AUC difference, scaled by 1000 for readability.
ax1 = plt.subplot(2, 3, 1)
for pred_type in df_loo['prediction_type'].unique():
    diffs = df_loo[df_loo['prediction_type'] == pred_type]['difference'] * 1000
    ax1.hist(diffs, bins=30, alpha=0.6, label=pred_type, edgecolor='black', linewidth=0.5)

ax1.set_xlabel('AUC Difference (×1000)', fontsize=11, fontweight='bold')
ax1.set_ylabel('Frequency', fontsize=11, fontweight='bold')
ax1.set_title('Distribution of Differences\nby Prediction Type', fontsize=12, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(alpha=0.3, linestyle='--')

# 2. Box plot of differences by prediction type
ax2 = plt.subplot(2, 3, 2)
# Adds a scaled helper column to df_loo in place; every later use of df_loo in this
# cell therefore sees the extra 'difference_x1000' column.
df_loo['difference_x1000'] = df_loo['difference'] * 1000
sns.boxplot(data=df_loo, x='prediction_type', y='difference_x1000', ax=ax2, palette='Set2')
ax2.set_xlabel('Prediction Type', fontsize=11, fontweight='bold')
ax2.set_ylabel('AUC Difference (×1000)', fontsize=11, fontweight='bold')
ax2.set_title('Distribution of Differences\n(Box Plot)', fontsize=12, fontweight='bold')
ax2.grid(axis='y', alpha=0.3, linestyle='--')

# 3. Max difference per batch
# For each (batch, prediction type) take the largest difference over all diseases
# and draw one line per prediction type across the batch indices.
ax3 = plt.subplot(2, 3, 3)
batch_max_diffs = df_loo.groupby(['batch_idx', 'prediction_type'])['difference'].max().reset_index()
for pred_type in df_loo['prediction_type'].unique():
    batch_data = batch_max_diffs[batch_max_diffs['prediction_type'] == pred_type]
    ax3.plot(batch_data['batch_idx'], batch_data['difference'] * 1000, 
             marker='o', linewidth=2, markersize=8, label=pred_type)

ax3.set_xlabel('Excluded Batch', fontsize=11, fontweight='bold')
ax3.set_ylabel('Max AUC Difference (×1000)', fontsize=11, fontweight='bold')
ax3.set_title('Max Difference per Excluded Batch', fontsize=12, fontweight='bold')
ax3.legend(fontsize=10)
ax3.grid(alpha=0.3, linestyle='--')
# One tick per batch index that actually appears in the data
ax3.set_xticks(sorted(df_loo['batch_idx'].unique()))

# 4. Summary statistics bar chart
# Grouped bars (mean / max / std of the difference) for each prediction type.
ax4 = plt.subplot(2, 3, 4)
summary_stats = df_loo.groupby('prediction_type')['difference'].agg(['mean', 'max', 'std']).reset_index()
x = np.arange(len(summary_stats))
width = 0.25

# The three bar series are offset by one bar width so they sit side by side
bars1 = ax4.bar(x - width, summary_stats['mean'] * 1000, width, label='Mean', color='#2E86AB', alpha=0.8)
bars2 = ax4.bar(x, summary_stats['max'] * 1000, width, label='Max', color='#A23B72', alpha=0.8)
bars3 = ax4.bar(x + width, summary_stats['std'] * 1000, width, label='Std', color='#F18F01', alpha=0.8)

ax4.set_ylabel('AUC Difference (×1000)', fontsize=11, fontweight='bold')
ax4.set_title('Summary Statistics', fontsize=12, fontweight='bold')
ax4.set_xticks(x)
ax4.set_xticklabels(summary_stats['prediction_type'], rotation=15, ha='right')
ax4.legend(fontsize=10)
ax4.grid(axis='y', alpha=0.3, linestyle='--')

# Add value labels
# Bars whose height is not strictly positive are left unlabelled.
for bars in [bars1, bars2, bars3]:
    for bar in bars:
        height = bar.get_height()
        if height > 0:
            ax4.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.2f}',
                    ha='center', va='bottom', fontsize=9)

# 5. Percentage within thresholds
# Heatmap of the share of comparisons whose absolute AUC difference is below each
# threshold, with one row per prediction type and one column per threshold.
ax5 = plt.subplot(2, 3, 5)
thresholds = [0.001, 0.005, 0.01, 0.05]
# Column labels are the thresholds scaled by 1000, kept in ascending numeric order.
threshold_labels = [f'{thresh*1000:.1f}' for thresh in thresholds]
threshold_data = []
for pred_type in df_loo['prediction_type'].unique():
    pred_diffs = df_loo[df_loo['prediction_type'] == pred_type]['difference']
    total = len(pred_diffs)
    for thresh, thresh_label in zip(thresholds, threshold_labels):
        within = (pred_diffs < thresh).sum()
        threshold_data.append({
            'prediction_type': pred_type,
            'threshold': thresh_label,
            'percentage': within / total * 100
        })

df_thresh = pd.DataFrame(threshold_data)
pivot_thresh = df_thresh.pivot(index='prediction_type', columns='threshold', values='percentage')
# pivot() sorts the string labels lexicographically ('1.0', '10.0', '5.0', '50.0');
# restore ascending numeric order so the heatmap reads left to right.
pivot_thresh = pivot_thresh.reindex(columns=threshold_labels)

# The colour scale is clamped to 95-100%; cells below 95% all share the lowest colour.
sns.heatmap(pivot_thresh, annot=True, fmt='.1f', cmap='YlGnBu', ax=ax5, 
            cbar_kws={'label': '% Within Threshold'}, vmin=95, vmax=100)
ax5.set_xlabel('Threshold (×1000)', fontsize=11, fontweight='bold')
ax5.set_ylabel('Prediction Type', fontsize=11, fontweight='bold')
ax5.set_title('Percentage Within Thresholds', fontsize=12, fontweight='bold')

# 6. Scatter plot: LOO vs Full Pooled AUCs
# One point per row of df_loo, coloured by prediction type; points on the diagonal
# have identical AUCs under both approaches.
ax6 = plt.subplot(2, 3, 6)
for pred_type in df_loo['prediction_type'].unique():
    pred_data = df_loo[df_loo['prediction_type'] == pred_type]
    ax6.scatter(pred_data['full_pooled_auc'], pred_data['loo_auc'], 
               alpha=0.5, s=30, label=pred_type)

# Add diagonal line
# Span the y=x reference line over the full range of AUCs in either column.
min_val = min(df_loo['loo_auc'].min(), df_loo['full_pooled_auc'].min())
max_val = max(df_loo['loo_auc'].max(), df_loo['full_pooled_auc'].max())
ax6.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2, alpha=0.7, label='y=x')

ax6.set_xlabel('Full Pooled AUC', fontsize=11, fontweight='bold')
ax6.set_ylabel('Leave-One-Out AUC', fontsize=11, fontweight='bold')
ax6.set_title('LOO vs Full Pooled AUCs\n(Perfect match = diagonal)', fontsize=12, fontweight='bold')
ax6.legend(fontsize=9, loc='lower right')
ax6.grid(alpha=0.3, linestyle='--')

plt.suptitle('Leave-One-Out Validation: Robustness of Pooled Phi Approach', 
             fontsize=14, fontweight='bold', y=0.98)

# Leave the top 3% of the figure free for the suptitle
plt.tight_layout(rect=[0, 0, 1, 0.97])

# Save figure
# Written before plt.show() so the file contains the fully drawn figure.
fig_path = '/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/loo_validation_plot.png'
plt.savefig(fig_path, dpi=300, bbox_inches='tight')
print(f"✓ Saved figure to: {fig_path}")

plt.show()

# Text summary: per prediction type, the spread of the absolute AUC difference
# (reported ×1000) and how many comparisons fall under 0.001 and 0.01.
print("\n" + "="*80)
print("SUMMARY STATISTICS:")
print("="*80)
for pred_type in df_loo['prediction_type'].unique():
    pred_data = df_loo[df_loo['prediction_type'] == pred_type]
    abs_diffs = pred_data['difference']
    n_total = len(pred_data)
    n_below_0001 = (abs_diffs < 0.001).sum()
    n_below_001 = (abs_diffs < 0.01).sum()

    print(f"\n{pred_type}:")
    print(f"  Mean difference: {abs_diffs.mean()*1000:.3f} (×1000)")
    print(f"  Max difference: {abs_diffs.max()*1000:.3f} (×1000)")
    print(f"  Std difference: {abs_diffs.std()*1000:.3f} (×1000)")
    print(f"  Median difference: {abs_diffs.median()*1000:.3f} (×1000)")
    print(f"  Comparisons < 0.001: {n_below_0001}/{n_total} ({n_below_0001/n_total*100:.1f}%)")
    print(f"  Comparisons < 0.01: {n_below_001}/{n_total} ({n_below_001/n_total*100:.1f}%)")
print("\n" + "="*80)

In [None]:
from fig5utils import *
import pandas as pd
import numpy as np
import torch

# Full PCE/PREVENT table and the disease-name list shared by every batch
pce_df_full = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/pce_prevent_full.csv')
disease_names = essentials['disease_names']

# Result accumulators, kept separate per analysis type.
# Fixed phi from ENROLLMENT data: 10-year, 30-year, static 10-year
fixed_enrollment_10yr_results, fixed_enrollment_30yr_results, fixed_enrollment_static_10yr_results = [], [], []

# Fixed phi from RETROSPECTIVE data: 10-year, 30-year, static 10-year
fixed_retrospective_10yr_results, fixed_retrospective_30yr_results, fixed_retrospective_static_10yr_results = [], [], []

# Load the full tensors only when the kernel does not already hold them
# (they are shared by both analyses).
try:
    Y_full
except NameError:
    Y_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/Y_tensor.pt')
try:
    E_full
except NameError:
    E_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_enrollment_full.pt')

# Loop through batch indices 0-40 (41 batches of 10,000 rows each). For every batch,
# evaluate two independent checkpoints (fixed phi from ENROLLMENT data and fixed phi
# from RETROSPECTIVE data) with 10-year dynamic, 30-year dynamic and static 10-year
# predictions. A missing or failing checkpoint is reported and the loop carries on.
for batch_idx in range(41):
    start_idx = batch_idx * 10000
    end_idx = (batch_idx + 1) * 10000
    
    print(f"\n{'='*80}")
    print(f"Processing batch {batch_idx}: {start_idx} to {end_idx}")
    print(f"{'='*80}")
    
    # Get pce_df subset for this batch
    # (copied and re-indexed from 0 so row positions line up with the batch tensors)
    pce_df_subset = pce_df_full[start_idx:end_idx].copy().reset_index(drop=True)
    
    # Extract batch from full tensors (shared for both analyses)
    Y_batch = Y_full[start_idx:end_idx]
    E_batch = E_full[start_idx:end_idx]
    
    # ===== FIXED PHI FROM ENROLLMENT DATA =====
    # NOTE(review): this path is rooted at .../CloudStorage/Dropbox/ while the
    # retrospective path below is rooted at .../CloudStorage/Dropbox-Personal/ — confirm
    # both roots are intended.
    fixed_enrollment_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_ENROLLMENT_pooled/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'
    
    try:
        print(f"\n--- Fixed Phi (ENROLLMENT) ---")
        fixed_ckpt = torch.load(fixed_enrollment_ckpt_path, weights_only=False)
        # Overwrites the weights of the shared global `model` in place
        model.load_state_dict(fixed_ckpt['model_state_dict'])
        
        # Update model.Y and model.N so forward() uses correct patients
        model.Y = torch.tensor(Y_batch, dtype=torch.float32)
        model.N = Y_batch.shape[0]
       
        # 10-year predictions
        print(f"Fixed Phi (ENROLLMENT) - 10 year predictions...")
        fixed_10yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
            model, Y_batch, E_batch, disease_names, pce_df_subset, 
            n_bootstraps=100, follow_up_duration_years=10, patient_indices=None
        )
        # Tag the result dict so it can be traced back to its batch and analysis
        fixed_10yr['batch_idx'] = batch_idx
        fixed_10yr['analysis_type'] = 'fixed_enrollment'
        fixed_enrollment_10yr_results.append(fixed_10yr)
        
        # 30-year predictions
        print(f"Fixed Phi (ENROLLMENT) - 30 year predictions...")
        fixed_30yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
            model, Y_batch, E_batch, disease_names, pce_df_subset, 
            n_bootstraps=100, follow_up_duration_years=30, patient_indices=None
        )
        fixed_30yr['batch_idx'] = batch_idx
        fixed_30yr['analysis_type'] = 'fixed_enrollment'
        fixed_enrollment_30yr_results.append(fixed_30yr)
        
        # Static 10-year predictions (using 1-year score)
        print(f"Fixed Phi (ENROLLMENT) - Static 10 year predictions...")
        fixed_static_10yr = evaluate_major_diseases_wsex_with_bootstrap(
            model=model,
            Y_100k=Y_batch,
            E_100k=E_batch,
            disease_names=disease_names,
            pce_df=pce_df_subset,
            n_bootstraps=100,
            follow_up_duration_years=10,
        )
        fixed_static_10yr['batch_idx'] = batch_idx
        fixed_static_10yr['analysis_type'] = 'fixed_enrollment'
        fixed_enrollment_static_10yr_results.append(fixed_static_10yr)
        
    except FileNotFoundError:
        # Missing checkpoint: report it and move on to the retrospective analysis
        print(f"Fixed phi (ENROLLMENT) checkpoint not found: {fixed_enrollment_ckpt_path}")
    except Exception as e:
        # Any other failure is logged with a full traceback; results appended before
        # the failure stay in their lists.
        print(f"Error processing fixed phi (ENROLLMENT) checkpoint {batch_idx}: {e}")
        import traceback
        traceback.print_exc()
    
    # ===== FIXED PHI FROM RETROSPECTIVE DATA =====
    fixed_retrospective_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/enrollment_predictions_fixedphi_RETROSPECTIVE_pooled/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'
    
    try:
        print(f"\n--- Fixed Phi (RETROSPECTIVE) ---")
        # The temporaries (fixed_ckpt, fixed_10yr, ...) reuse the names from the
        # enrollment section; only the destination lists and the tag differ.
        fixed_ckpt = torch.load(fixed_retrospective_ckpt_path, weights_only=False)
        model.load_state_dict(fixed_ckpt['model_state_dict'])
        
        # Update model.Y and model.N so forward() uses correct patients
        model.Y = torch.tensor(Y_batch, dtype=torch.float32)
        model.N = Y_batch.shape[0]
       
        # 10-year predictions
        print(f"Fixed Phi (RETROSPECTIVE) - 10 year predictions...")
        fixed_10yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
            model, Y_batch, E_batch, disease_names, pce_df_subset, 
            n_bootstraps=100, follow_up_duration_years=10, patient_indices=None
        )
        fixed_10yr['batch_idx'] = batch_idx
        fixed_10yr['analysis_type'] = 'fixed_retrospective'
        fixed_retrospective_10yr_results.append(fixed_10yr)
        
        # 30-year predictions
        print(f"Fixed Phi (RETROSPECTIVE) - 30 year predictions...")
        fixed_30yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
            model, Y_batch, E_batch, disease_names, pce_df_subset, 
            n_bootstraps=100, follow_up_duration_years=30, patient_indices=None
        )
        fixed_30yr['batch_idx'] = batch_idx
        fixed_30yr['analysis_type'] = 'fixed_retrospective'
        fixed_retrospective_30yr_results.append(fixed_30yr)
        
        # Static 10-year predictions (using 1-year score)
        print(f"Fixed Phi (RETROSPECTIVE) - Static 10 year predictions...")
        fixed_static_10yr = evaluate_major_diseases_wsex_with_bootstrap(
            model=model,
            Y_100k=Y_batch,
            E_100k=E_batch,
            disease_names=disease_names,
            pce_df=pce_df_subset,
            n_bootstraps=100,
            follow_up_duration_years=10,
        )
        fixed_static_10yr['batch_idx'] = batch_idx
        fixed_static_10yr['analysis_type'] = 'fixed_retrospective'
        fixed_retrospective_static_10yr_results.append(fixed_static_10yr)
        
    except FileNotFoundError:
        print(f"Fixed phi (RETROSPECTIVE) checkpoint not found: {fixed_retrospective_ckpt_path}")
    except Exception as e:
        print(f"Error processing fixed phi (RETROSPECTIVE) checkpoint {batch_idx}: {e}")
        import traceback
        traceback.print_exc()

# Closing banner plus the number of batches that produced results for each analysis,
# emitted as one newline-joined message.
fixed_phi_summary_lines = [
    f"\n{'='*80}",
    "Completed processing all checkpoints!",
    f"{'='*80}",
    f"Fixed Enrollment - 10yr: {len(fixed_enrollment_10yr_results)} batches",
    f"Fixed Enrollment - 30yr: {len(fixed_enrollment_30yr_results)} batches",
    f"Fixed Retrospective - 10yr: {len(fixed_retrospective_10yr_results)} batches",
    f"Fixed Retrospective - 30yr: {len(fixed_retrospective_30yr_results)} batches",
]
print("\n".join(fixed_phi_summary_lines))


In [None]:

# Get predictions (pi) from the model
#
# For two joint-phi checkpoints (batches 0-10000 and 20000-30000): load the weights
# into the shared `model`, run a gradient-free forward pass and save the resulting
# pi tensor next to the checkpoint.
# NOTE(review): forward() runs with whatever model.Y / model.N the kernel's `model`
# currently holds (other cells reassign them per batch) — confirm they match the
# batch whose weights are loaded here.

ckpt_0_10000=torch.load("/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_prediction_jointphi_sex_pcs/enrollment_model_W0.0001_batch_0_10000.pt")

ckpt_test=model.load_state_dict(ckpt_0_10000['model_state_dict'])
with torch.no_grad():
    pi, _, _ = model.forward()  # pi shape: (N, D, T)

# Save path uses the same 'CloudStorage' spelling as the load path; the previous
# 'Cloudstorage' only resolved on case-insensitive filesystems.
torch.save(pi,"/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_prediction_jointphi_sex_pcs/pi_enroll_sex_0_10000.pt")
del pi  # release the first batch's predictions before computing the next
####

ckpt_20000_30000=torch.load("/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_prediction_jointphi_sex_pcs/enrollment_model_W0.0001_batch_20000_30000.pt")

ckpt_test=model.load_state_dict(ckpt_20000_30000['model_state_dict'])
with torch.no_grad():
    pi, _, _ = model.forward()  # pi shape: (N, D, T)

torch.save(pi,"/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_prediction_jointphi_sex_pcs/pi_enroll_sex_20000_30000.pt")



# Gold standard

In [None]:
from fig5utils import *
# Load the batch 0-10000 checkpoint from enrollment_retrospective_full into the
# shared `model`, then run the dynamic bootstrap evaluation (100 bootstraps,
# 30-year follow-up).
# NOTE(review): relies on `model`, `Y_100k`, `E_100k` and `pce_df` already existing
# in the kernel from earlier cells.
disease_names=essentials['disease_names']
ckpt=torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_retrospective_full/enrollment_model_W0.0001_batch_0_10000.pt')
model.load_state_dict(ckpt['model_state_dict'])

evaluate_major_diseases_wsex_with_bootstrap_dynamic(model, Y_100k, E_100k, disease_names, pce_df, n_bootstraps=100, follow_up_duration_years=30, patient_indices=None)

# fixed phi from enrollment

In this section we use the fixed parameters that were estimated from `enrollment_model_W0.0001_fulldata_sexspecific.pt`, together with the fixed phi.

In [None]:

from fig5utils import *
# Load the fixed-phi checkpoint for batch 0-10000 (enrollment_predictions_fixedphi_withpcs/output)
# into the shared `model`, then run the dynamic bootstrap evaluation
# (100 bootstraps, 30-year follow-up).
# NOTE(review): relies on `model`, `Y_100k`, `E_100k` and `pce_df` from earlier cells.
disease_names=essentials['disease_names']
ckpt=torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_withpcs/output/model_enroll_fixedphi_sex_0_10000.pt')
model.load_state_dict(ckpt['model_state_dict'])

evaluate_major_diseases_wsex_with_bootstrap_dynamic(model, Y_100k, E_100k, disease_names, pce_df, n_bootstraps=100, follow_up_duration_years=30, patient_indices=None)

In [None]:
from utils import *
# Static (non-dynamic) bootstrap evaluation with 10-year follow-up and 20 bootstraps,
# using whichever weights are currently loaded in the shared `model`.
results = evaluate_major_diseases_wsex_with_bootstrap(
    model=model,
    Y_100k=Y_100k,
    E_100k=E_100k,
    disease_names=disease_names,
    pce_df=pce_df,
    n_bootstraps=20,
    follow_up_duration_years=10,
)

# joint 10 year

In [None]:
from fig5utils import *
# Load the joint-phi checkpoint for batch 0-10000 into the shared `model` and run the
# dynamic bootstrap evaluation (100 bootstraps, 30-year follow-up) on the Y stored
# inside the checkpoint.
# NOTE(review): Y comes from the checkpoint while E_100k and pce_df come from the
# kernel — confirm all three describe the same patients in the same order.
disease_names=essentials['disease_names']
ckpt=torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_prediction_jointphi_sex_pcs/enrollment_model_W0.0001_batch_0_10000.pt')
model.load_state_dict(ckpt['model_state_dict'])

evaluate_major_diseases_wsex_with_bootstrap_dynamic(model, ckpt['Y'], E_100k, disease_names, pce_df, n_bootstraps=100, follow_up_duration_years=30, patient_indices=None)

In [None]:
# Exploratory check: list the entries stored in the most recently loaded checkpoint.
ckpt.keys()

In [None]:
from fig5utils import *
# Load the pooled fixed-phi (ENROLLMENT) checkpoint for batch 0-10000 into the shared
# `model`, then run the dynamic bootstrap evaluation (100 bootstraps, 30-year follow-up).
# NOTE(review): relies on `model`, `Y_100k`, `E_100k` and `pce_df` from earlier cells.
disease_names=essentials['disease_names']
ckpt=torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_ENROLLMENT_pooled/model_enroll_fixedphi_sex_0_10000.pt')
model.load_state_dict(ckpt['model_state_dict'])

evaluate_major_diseases_wsex_with_bootstrap_dynamic(model, Y_100k, E_100k, disease_names, pce_df, n_bootstraps=100, follow_up_duration_years=30, patient_indices=None)

In [None]:
# Static (non-dynamic) bootstrap evaluation with 10-year follow-up and 20 bootstraps,
# using whichever weights are currently loaded in the shared `model`.
results = evaluate_major_diseases_wsex_with_bootstrap(
    model=model,
    Y_100k=Y_100k,
    E_100k=E_100k,
    disease_names=disease_names,
    pce_df=pce_df,
    n_bootstraps=20,
    follow_up_duration_years=10,
)

In [None]:
from fig5utils import *
# Load the joint-phi checkpoint for batch 20000-30000, keep its stored Y in `Y200k`
# (reused by the next cell), and run the dynamic bootstrap evaluation
# (100 bootstraps, 30-year follow-up).
# NOTE(review): despite its name, `Y200k` holds the Y of the 20000-30000 checkpoint;
# E_100k and pce_df are not sliced to that range here — confirm the rows line up.
disease_names=essentials['disease_names']
ckpt=torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_prediction_jointphi_sex_pcs/enrollment_model_W0.0001_batch_20000_30000.pt')
Y200k=ckpt['Y']
model.load_state_dict(ckpt['model_state_dict'])

evaluate_major_diseases_wsex_with_bootstrap_dynamic(model, Y200k, E_100k, disease_names, pce_df, n_bootstraps=100, follow_up_duration_years=30, patient_indices=None)

In [None]:
# Load the pooled fixed-phi (ENROLLMENT) checkpoint for batch 20000-30000 and evaluate
# it on `Y200k`, which must already have been set by the preceding cell.
# NOTE(review): E_100k and pce_df are not sliced to 20000-30000 here — confirm alignment.
ckpt=torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_ENROLLMENT_pooled/model_enroll_fixedphi_sex_20000_30000.pt')
model.load_state_dict(ckpt['model_state_dict'])

evaluate_major_diseases_wsex_with_bootstrap_dynamic(model, Y200k, E_100k, disease_names, pce_df, n_bootstraps=100, follow_up_duration_years=30, patient_indices=None)

## check joint (we know this was run with enrollment)

In [None]:
# Load the joint-phi sex-specific checkpoint from the backup folder and display the
# shape of its stored 'G' entry.
ckpt=torch.load('/Users/sarahurbut/Library/CloudStorage/DB_backup_5132025941p/enrollment_model_W0.0001_jointphi_sexspecific.pt')
ckpt['G'].shape


In [None]:

# Rebuild the global `model` with dimensions taken from Y_100k and the G matrix stored
# in the checkpoint loaded above, load that checkpoint's weights, and run the dynamic
# bootstrap evaluation (100 bootstraps, 30-year follow-up).
G_with_sex = ckpt['G']  # sex should be numeric (e.g., 0/1)


model = AladynSurvivalFixedKernelsAvgLoss_clust_logitInit_psitest(
    N=Y_100k.shape[0], 
    D=Y_100k.shape[1], 
    T=Y_100k.shape[2], 
    K=20,
    P=G_with_sex.shape[1],
    init_sd_scaler=1e-1,
    G=G_with_sex, 
    Y=Y_100k,
    genetic_scale=1,
    W=0,
    R=0,
    prevalence_t=essentials['prevalence_t'],
    signature_references=signature_refs,  # Only pass signature refs
    healthy_reference=True,  # NOTE(review): an earlier comment here said "Explicitly set to None", but True is passed — confirm the intended value
    disease_names=essentials['disease_names']
)

# Replaces the freshly initialised parameters with the checkpoint's weights
model.load_state_dict(ckpt['model_state_dict'])

evaluate_major_diseases_wsex_with_bootstrap_dynamic(model, Y_100k, E_100k, disease_names, pce_df, n_bootstraps=100, follow_up_duration_years=30, patient_indices=None)

## joint fun

In [None]:
from fig5utils import *
# Joint-phi evaluation for batch 0-10000: slice the matching rows of the full PCE
# table, load the batch checkpoint into the shared `model`, and run the dynamic
# bootstrap evaluation (100 bootstraps, 30-year follow-up) on the checkpoint's Y.
# NOTE(review): the slice keeps the original row labels (no reset_index), and E_100k
# is passed unsliced — confirm the evaluation function handles both as intended.
pce_df_full=pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/pce_prevent_full.csv')
pce_df_subset=pce_df_full[0:10000]
disease_names=essentials['disease_names']
ckpt=torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_prediction_jointphi_sex_pcs/enrollment_model_W0.0001_batch_0_10000.pt')
model.load_state_dict(ckpt['model_state_dict'])

evaluate_major_diseases_wsex_with_bootstrap_dynamic(model, ckpt['Y'], E_100k, disease_names, pce_df_subset, n_bootstraps=100, follow_up_duration_years=30, patient_indices=None)

In [None]:
from fig5utils import *
# Joint-phi evaluation for batch 10000-20000: slice the matching rows of the full PCE
# table, load the batch checkpoint into the shared `model`, and run the dynamic
# bootstrap evaluation (100 bootstraps, 30-year follow-up) on the checkpoint's Y.
# NOTE(review): the slice keeps row labels 10000-19999 (no reset_index), and E_100k is
# passed unsliced — confirm the rows line up with the checkpoint's Y.
pce_df_full=pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/pce_prevent_full.csv')
pce_df_subset=pce_df_full[10000:20000]
disease_names=essentials['disease_names']
ckpt=torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_prediction_jointphi_sex_pcs/enrollment_model_W0.0001_batch_10000_20000.pt')
model.load_state_dict(ckpt['model_state_dict'])

evaluate_major_diseases_wsex_with_bootstrap_dynamic(model, ckpt['Y'], E_100k, disease_names, pce_df_subset, n_bootstraps=100, follow_up_duration_years=30, patient_indices=None)

In [None]:
from fig5utils import *
# Joint-phi evaluation for batch 20000-30000: slice the matching rows of the full PCE
# table, load the batch checkpoint into the shared `model`, and run the dynamic
# bootstrap evaluation (100 bootstraps, 30-year follow-up) on the checkpoint's Y.
# NOTE(review): the slice keeps row labels 20000-29999 (no reset_index), and E_100k is
# passed unsliced — confirm the rows line up with the checkpoint's Y.
pce_df_full=pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/pce_prevent_full.csv')
pce_df_subset=pce_df_full[20000:30000]
disease_names=essentials['disease_names']
ckpt=torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_prediction_jointphi_sex_pcs/enrollment_model_W0.0001_batch_20000_30000.pt')
model.load_state_dict(ckpt['model_state_dict'])

evaluate_major_diseases_wsex_with_bootstrap_dynamic(model, ckpt['Y'], E_100k, disease_names, pce_df_subset, n_bootstraps=100, follow_up_duration_years=30, patient_indices=None)

In [None]:
from fig5utils import *
# Joint-phi evaluation for batch 330000-340000. Unlike the sibling cells, this one does
# not re-read the CSV: `pce_df_full` must already exist in the kernel.
# NOTE(review): the slice keeps its original row labels, and E_100k is passed
# unsliced — confirm the rows line up with the checkpoint's Y.
pce_df_subset=pce_df_full[330000:340000]
disease_names=essentials['disease_names']
ckpt=torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_prediction_jointphi_sex_pcs/enrollment_model_W0.0001_batch_330000_340000.pt')
model.load_state_dict(ckpt['model_state_dict'])

evaluate_major_diseases_wsex_with_bootstrap_dynamic(model, ckpt['Y'], E_100k, disease_names, pce_df_subset, n_bootstraps=100, follow_up_duration_years=30, patient_indices=None)

# do it all without washout, this is using:

* the fixed phi (estimated retrospectively on one batch :)
* the joint enrollment (where phi is estimated jointly rather than held fixed)

In [None]:
from fig5utils import *
import pandas as pd
import numpy as np
import torch

# Load full pce_df
# Full PCE/PREVENT table and the disease-name list shared by every batch
pce_df_full = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/pce_prevent_full.csv')
disease_names = essentials['disease_names']

# Accumulators for the dynamic predictions (10-year and 30-year), one pair per approach
joint_10yr_results, joint_30yr_results = [], []
fixed_10yr_results, fixed_30yr_results = [], []

# Accumulators for the static predictions (1-year score for 10-year outcome)
joint_static_10yr_results, fixed_static_10yr_results = [], []

# Loop through batch indices 0-40 (checkpoints batch_0_10000 through batch_400000_410000).
# For each batch, evaluate the joint-phi checkpoint and then the fixed-phi checkpoint
# with 10-year dynamic, 30-year dynamic and static 10-year predictions.
# The two evaluations are independent: a missing or failing joint-phi checkpoint is
# reported and the fixed-phi checkpoint for the same batch is still evaluated.
for batch_idx in range(41):
    start_idx = batch_idx * 10000
    end_idx = (batch_idx + 1) * 10000
    
    print(f"\n{'='*80}")
    print(f"Processing batch {batch_idx}: {start_idx} to {end_idx}")
    print(f"{'='*80}")
    
    # Get pce_df subset for this batch
    # (copied and re-indexed from 0 so row positions line up with the batch tensors)
    pce_df_subset = pce_df_full[start_idx:end_idx].copy().reset_index(drop=True)
    
    # === JOINT PHI CHECKPOINTS ===
    joint_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_prediction_jointphi_sex_pcs/enrollment_model_W0.0001_batch_{start_idx}_{end_idx}.pt'
    
    try:
        joint_ckpt = torch.load(joint_ckpt_path, weights_only=False)
        model.load_state_dict(joint_ckpt['model_state_dict'])
        
        # Use Y from checkpoint and update model.Y so forward() uses correct patients
        Y_batch = joint_ckpt['Y']
        model.Y = torch.tensor(Y_batch, dtype=torch.float32)
        model.N = Y_batch.shape[0]  # Update N to match new Y size
        
        # NOTE(review): the joint-phi evaluations below pass the global E_100k for every
        # batch, whereas the fixed-phi section slices E_full[start_idx:end_idx] — confirm
        # E_100k is aligned with the checkpoint's Y for batches other than the first.
        
        # 10-year predictions
        print(f"\nJoint Phi - 10 year predictions...")
        joint_10yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
            model, Y_batch, E_100k, disease_names, pce_df_subset, 
            n_bootstraps=100, follow_up_duration_years=10, patient_indices=None
        )
        joint_10yr['batch_idx'] = batch_idx
        joint_10yr_results.append(joint_10yr)
        
        # 30-year predictions
        print(f"\nJoint Phi - 30 year predictions...")
        joint_30yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
            model, Y_batch, E_100k, disease_names, pce_df_subset, 
            n_bootstraps=100, follow_up_duration_years=30, patient_indices=None
        )
        joint_30yr['batch_idx'] = batch_idx
        joint_30yr_results.append(joint_30yr)
        
        # Static 10-year predictions (using 1-year score)
        print(f"\nJoint Phi - Static 10 year predictions (1-year score)...")
        joint_static_10yr = evaluate_major_diseases_wsex_with_bootstrap(
            model=model,
            Y_100k=Y_batch,
            E_100k=E_100k,
            disease_names=disease_names,
            pce_df=pce_df_subset,
            n_bootstraps=100,
            follow_up_duration_years=10,
        )
        joint_static_10yr['batch_idx'] = batch_idx
        joint_static_10yr_results.append(joint_static_10yr)
        
    except FileNotFoundError:
        # No `continue` here: fall through so the fixed-phi checkpoint still runs.
        print(f"Joint phi checkpoint not found: {joint_ckpt_path}")
    except Exception as e:
        # Report the failure with its traceback, then fall through to the fixed-phi
        # section; results appended before the failure stay in their lists.
        print(f"Error processing joint phi checkpoint {batch_idx}: {e}")
        import traceback
        traceback.print_exc()
    
    # === FIXED PHI CHECKPOINTS ===
    fixed_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_withpcs_fromclaudeoutput/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'
    
    try:
        fixed_ckpt = torch.load(fixed_ckpt_path, weights_only=False)
        model.load_state_dict(fixed_ckpt['model_state_dict'])
        
        # For fixed phi, Y is not saved in checkpoint, so extract from full Y tensor
        # Load full Y tensor if not already available
        if 'Y_full' not in globals():
            Y_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/Y_tensor.pt')
        if 'E_full' not in globals():
            E_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_enrollment_full.pt')
        
        # Extract batch from full tensors
        Y_batch = Y_full[start_idx:end_idx]
        E_batch = E_full[start_idx:end_idx]
        
        # Update model.Y and model.N so forward() uses correct patients
        model.Y = torch.tensor(Y_batch, dtype=torch.float32)
        model.N = Y_batch.shape[0]  # Update N to match new Y size
       
        # 10-year predictions
        print(f"\nFixed Phi - 10 year predictions...")
        fixed_10yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
            model, Y_batch, E_batch, disease_names, pce_df_subset, 
            n_bootstraps=100, follow_up_duration_years=10, patient_indices=None
        )
        fixed_10yr['batch_idx'] = batch_idx
        fixed_10yr_results.append(fixed_10yr)
        
        # 30-year predictions
        print(f"\nFixed Phi - 30 year predictions...")
        fixed_30yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
            model, Y_batch, E_batch, disease_names, pce_df_subset, 
            n_bootstraps=100, follow_up_duration_years=30, patient_indices=None
        )
        fixed_30yr['batch_idx'] = batch_idx
        fixed_30yr_results.append(fixed_30yr)
        
        # Static 10-year predictions (using 1-year score)
        print(f"\nFixed Phi - Static 10 year predictions (1-year score)...")
        fixed_static_10yr = evaluate_major_diseases_wsex_with_bootstrap(
            model=model,
            Y_100k=Y_batch,
            E_100k=E_batch,
            disease_names=disease_names,
            pce_df=pce_df_subset,
            n_bootstraps=100,
            follow_up_duration_years=10,
        )
        fixed_static_10yr['batch_idx'] = batch_idx
        fixed_static_10yr_results.append(fixed_static_10yr)
        
    except FileNotFoundError:
        print(f"Fixed phi checkpoint not found: {fixed_ckpt_path}")
    except Exception as e:
        print(f"Error processing fixed phi checkpoint {batch_idx}: {e}")
        import traceback
        traceback.print_exc()

# Closing banner for the joint/fixed checkpoint loop, emitted as one message
print("\n".join([
    f"\n{'='*80}",
    "Completed processing all checkpoints!",
    f"{'='*80}",
]))


In [None]:
# ===== AGGREGATE AND SAVE RESULTS FOR JOINT AND OLD FIXED APPROACHES =====
# Paste this cell after your loop completes (after line 147 in lifetime.ipynb)

# Banner announcing the aggregation step, emitted as one message
print("\n".join([
    f"\n{'='*80}",
    "AGGREGATING RESULTS ACROSS ALL BATCHES (JOINT & OLD FIXED)",
    f"{'='*80}",
]))

def aggregate_results_to_dataframe(results_list, analysis_name):
    """
    Aggregate per-batch evaluation results into one summary row per disease.

    Parameters
    ----------
    results_list : list of dict
        One dict per batch. Disease names map to metric dicts; the metric keys
        read here are 'auc', 'ci_lower', 'ci_upper', 'n_events' and
        'event_rate'. The metadata keys 'batch_idx' and 'analysis_type' are
        ignored, as is any entry whose value is not a dict.
    analysis_name : str
        Label used only in the warning printed when ``results_list`` is empty.

    Returns
    -------
    pandas.DataFrame
        Indexed by 'Disease' and sorted by 'AUC_median' (descending), with
        columns AUC_median/mean/std/min/max, CI_lower_median, CI_upper_median,
        CI_lower_min, CI_upper_max, Total_Events, Mean_Event_Rate and
        N_Batches (number of batches contributing a valid AUC). A disease is
        included only if at least one batch has a valid AUC. An empty
        DataFrame is returned when there is nothing to aggregate.
    """
    if not results_list:
        print(f"Warning: No results found for {analysis_name}")
        return pd.DataFrame()

    metadata_keys = ('batch_idx', 'analysis_type')

    def _is_valid(value):
        """Return True when value is a number that is neither None nor NaN."""
        if value is None:
            return False
        try:
            return not np.isnan(value)
        except TypeError:
            # Non-numeric payload (e.g. a string): treat as missing.
            return False

    # Union of disease names over *all* batches, in first-seen order, so a
    # disease missing from the first batch is not silently dropped.
    disease_names_list = list(dict.fromkeys(
        key
        for result in results_list
        for key in result
        if key not in metadata_keys
    ))

    # Collect all metrics across batches
    aggregated_data = []
    for disease in disease_names_list:
        aucs = []
        ci_lowers = []
        ci_uppers = []
        n_events_list = []
        event_rates = []

        for result in results_list:
            metrics = result.get(disease)
            if not isinstance(metrics, dict):
                continue
            if _is_valid(metrics.get('auc')):
                aucs.append(metrics['auc'])
            if _is_valid(metrics.get('ci_lower')):
                ci_lowers.append(metrics['ci_lower'])
            if _is_valid(metrics.get('ci_upper')):
                ci_uppers.append(metrics['ci_upper'])
            # None counts as "not reported" so np.sum below cannot fail on it.
            if metrics.get('n_events') is not None:
                n_events_list.append(metrics['n_events'])
            if metrics.get('event_rate') is not None:
                event_rates.append(metrics['event_rate'])

        if aucs:  # Only add if we have at least one valid AUC
            aggregated_data.append({
                'Disease': disease,
                'AUC_median': np.median(aucs),
                'AUC_mean': np.mean(aucs),
                'AUC_std': np.std(aucs),  # numpy default: population std (ddof=0)
                'AUC_min': np.min(aucs),
                'AUC_max': np.max(aucs),
                'CI_lower_median': np.median(ci_lowers) if ci_lowers else np.nan,
                'CI_upper_median': np.median(ci_uppers) if ci_uppers else np.nan,
                'CI_lower_min': np.min(ci_lowers) if ci_lowers else np.nan,
                'CI_upper_max': np.max(ci_uppers) if ci_uppers else np.nan,
                'Total_Events': np.sum(n_events_list) if n_events_list else np.nan,
                'Mean_Event_Rate': np.mean(event_rates) if event_rates else np.nan,
                'N_Batches': len(aucs)
            })

    df = pd.DataFrame(aggregated_data)
    if not df.empty:
        df = df.set_index('Disease').sort_values('AUC_median', ascending=False)
    return df

# Aggregate every per-batch result list into a per-disease summary frame.
# The *_results lists are expected to have been filled by the evaluation loop.
print("\nAggregating Joint Phi results...")
joint_10yr_df = aggregate_results_to_dataframe(joint_10yr_results, "Joint 10yr")
joint_30yr_df = aggregate_results_to_dataframe(joint_30yr_results, "Joint 30yr")
joint_static_10yr_df = aggregate_results_to_dataframe(joint_static_10yr_results, "Joint Static 10yr")

print("Aggregating Old Fixed Phi results...")
fixed_10yr_df = aggregate_results_to_dataframe(fixed_10yr_results, "Fixed 10yr")
fixed_30yr_df = aggregate_results_to_dataframe(fixed_30yr_results, "Fixed 30yr")
fixed_static_10yr_df = aggregate_results_to_dataframe(fixed_static_10yr_results, "Fixed Static 10yr")

# Save individual DataFrames
# NOTE(review): absolute, machine-specific output directory.
output_dir = '/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/'
print(f"\nSaving aggregated results to {output_dir}...")

joint_10yr_df.to_csv(f'{output_dir}pooled_joint_10yr.csv')
joint_30yr_df.to_csv(f'{output_dir}pooled_joint_30yr.csv')
joint_static_10yr_df.to_csv(f'{output_dir}pooled_joint_static_10yr.csv')

fixed_10yr_df.to_csv(f'{output_dir}pooled_old_fixed_10yr.csv')
fixed_30yr_df.to_csv(f'{output_dir}pooled_old_fixed_30yr.csv')
fixed_static_10yr_df.to_csv(f'{output_dir}pooled_old_fixed_static_10yr.csv')

print("✓ Saved individual result files")


def _auc_median_or_nan(frame):
    """Return frame['AUC_median'], or NaN when the aggregated frame is empty.

    An empty aggregation result is a bare ``pd.DataFrame()`` with no columns,
    so indexing it with 'AUC_median' would raise KeyError.
    """
    return np.nan if frame.empty else frame['AUC_median']


# Create a combined comparison DataFrame
print("\nCreating combined comparison DataFrame...")
all_diseases = set()
for df in [joint_10yr_df, joint_30yr_df, joint_static_10yr_df,
           fixed_10yr_df, fixed_30yr_df, fixed_static_10yr_df]:
    if not df.empty:
        all_diseases.update(df.index)

# Column assignment aligns on the disease index; diseases absent from a given
# analysis (or whole analyses with no results) end up as NaN.
comparison_df = pd.DataFrame(index=sorted(all_diseases))
comparison_df['Joint_10yr'] = _auc_median_or_nan(joint_10yr_df)
comparison_df['Joint_30yr'] = _auc_median_or_nan(joint_30yr_df)
comparison_df['Joint_Static_10yr'] = _auc_median_or_nan(joint_static_10yr_df)
comparison_df['Old_Fixed_10yr'] = _auc_median_or_nan(fixed_10yr_df)
comparison_df['Old_Fixed_30yr'] = _auc_median_or_nan(fixed_30yr_df)
comparison_df['Old_Fixed_Static_10yr'] = _auc_median_or_nan(fixed_static_10yr_df)

comparison_df.to_csv(f'{output_dir}pooled_joint_and_old_fixed_comparison.csv')
print("✓ Saved combined comparison file: pooled_joint_and_old_fixed_comparison.csv")

# Print summary
print(f"\n{'='*80}")
print("SUMMARY OF AGGREGATED RESULTS")
print(f"{'='*80}")
print(f"\nJoint - 10yr: {len(joint_10yr_df)} diseases")
print(f"Joint - 30yr: {len(joint_30yr_df)} diseases")
print(f"Joint - Static 10yr: {len(joint_static_10yr_df)} diseases")
print(f"Old Fixed - 10yr: {len(fixed_10yr_df)} diseases")
print(f"Old Fixed - 30yr: {len(fixed_30yr_df)} diseases")
print(f"Old Fixed - Static 10yr: {len(fixed_static_10yr_df)} diseases")

print(f"\n{'='*80}")
print("TOP 10 DISEASES BY AUC (Joint 10yr)")
print(f"{'='*80}")
if not joint_10yr_df.empty:
    print(joint_10yr_df[['AUC_median', 'CI_lower_median', 'CI_upper_median', 'N_Batches']].head(10).round(4))

print(f"\n{'='*80}")
print("TOP 10 DISEASES BY AUC (Old Fixed 10yr)")
print(f"{'='*80}")
if not fixed_10yr_df.empty:
    print(fixed_10yr_df[['AUC_median', 'CI_lower_median', 'CI_upper_median', 'N_Batches']].head(10).round(4))

print(f"\n{'='*80}")
print("COMPARISON: Joint vs Old Fixed (Static 10yr)")
print(f"{'='*80}")
if not joint_static_10yr_df.empty and not fixed_static_10yr_df.empty:
    # 'Difference' is Old Fixed minus Joint, so positive values favour Old Fixed.
    static_comparison = pd.DataFrame({
        'Joint_Static_10yr': joint_static_10yr_df['AUC_median'],
        'Old_Fixed_Static_10yr': fixed_static_10yr_df['AUC_median'],
        'Difference': fixed_static_10yr_df['AUC_median'] - joint_static_10yr_df['AUC_median']
    }).sort_values('Difference', ascending=False)
    print(static_comparison.round(4))
    print(f"\nMean difference (Old Fixed - Joint): {static_comparison['Difference'].mean():.4f}")
    print(f"Diseases where Old Fixed > Joint: {(static_comparison['Difference'] > 0).sum()} / {len(static_comparison)}")

print(f"\n{'='*80}")
print("All results saved successfully!")
print(f"{'='*80}")


In [None]:
# Rerun only Fixed Phi evaluations with proper Y batch extraction
# (Joint phi results should already be in joint_*_results variables)
# NOTE(review): relies on `model`, `pce_df_full`, `disease_names` and the
# evaluate_* helpers already existing in the kernel.

# Clear only fixed phi results to rerun them
fixed_10yr_results = []
fixed_30yr_results = []
fixed_static_10yr_results = []

# Load full Y and E tensors once
print("Loading full Y and E tensors...")
Y_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/Y_tensor.pt')
E_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_enrollment_full.pt')
print("Loaded full tensors!")

# Loop through batch indices 0-40, i.e. batch_0_10000 through batch_400000_410000.
# A batch whose checkpoint file does not exist is reported and skipped.
for batch_idx in range(41):
    start_idx = batch_idx * 10000
    end_idx = (batch_idx + 1) * 10000

    print(f"\n{'='*80}")
    print(f"Processing FIXED PHI batch {batch_idx}: {start_idx} to {end_idx}")
    print(f"{'='*80}")

    # Get pce_df subset for this batch (explicit positional row slice)
    pce_df_subset = pce_df_full.iloc[start_idx:end_idx].copy().reset_index(drop=True)

    # === FIXED PHI CHECKPOINTS ===
    fixed_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_withpcs_fromclaudeoutput/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'

    try:
        fixed_ckpt = torch.load(fixed_ckpt_path, weights_only=False)
        model.load_state_dict(fixed_ckpt['model_state_dict'])

        # For fixed phi, Y is not saved in checkpoint, so extract from full Y tensor
        Y_batch = Y_full[start_idx:end_idx]
        E_batch = E_full[start_idx:end_idx]

        # Update model.Y and model.N so forward() uses correct patients.
        # detach().clone() makes the same independent copy torch.tensor() made,
        # without the copy-construct UserWarning raised for tensor inputs.
        model.Y = torch.as_tensor(Y_batch).detach().clone().to(dtype=torch.float32)
        model.N = Y_batch.shape[0]  # Update N to match new Y size

        # 10-year predictions
        print(f"\nFixed Phi - 10 year predictions...")
        fixed_10yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
            model, Y_batch, E_batch, disease_names, pce_df_subset,
            n_bootstraps=100, follow_up_duration_years=10, patient_indices=None
        )
        fixed_10yr['batch_idx'] = batch_idx
        fixed_10yr_results.append(fixed_10yr)

        # 30-year predictions
        print(f"\nFixed Phi - 30 year predictions...")
        fixed_30yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
            model, Y_batch, E_batch, disease_names, pce_df_subset,
            n_bootstraps=100, follow_up_duration_years=30, patient_indices=None
        )
        fixed_30yr['batch_idx'] = batch_idx
        fixed_30yr_results.append(fixed_30yr)

        # Static 10-year predictions (using 1-year score)
        print(f"\nFixed Phi - Static 10 year predictions (1-year score)...")
        fixed_static_10yr = evaluate_major_diseases_wsex_with_bootstrap(
            model=model,
            Y_100k=Y_batch,
            E_100k=E_batch,
            disease_names=disease_names,
            pce_df=pce_df_subset,
            n_bootstraps=100,
            follow_up_duration_years=10,
        )
        fixed_static_10yr['batch_idx'] = batch_idx
        fixed_static_10yr_results.append(fixed_static_10yr)

    except FileNotFoundError:
        print(f"Fixed phi checkpoint not found: {fixed_ckpt_path}")
    except Exception as e:
        # Best-effort: report the failure and move on to the next batch.
        # Results appended before the failure are kept, so the three result
        # lists may cover different batches.
        print(f"Error processing fixed phi checkpoint {batch_idx}: {e}")

print(f"\n{'='*80}")
print("Completed processing all FIXED PHI checkpoints!")
print(f"{'='*80}")


In [None]:
# Compute median AUC across all batches
def compute_median_aucs(results_list):
    """Compute the per-disease median AUC across batches.

    Parameters
    ----------
    results_list : list of dict
        One dict per batch mapping disease name -> metric dict containing an
        'auc' entry. The metadata keys 'batch_idx' and 'analysis_type' are
        ignored; entries that are not dicts, and AUCs that are None or NaN,
        are skipped.

    Returns
    -------
    pandas.DataFrame
        Single column 'median_auc', indexed by disease name and sorted in
        descending order (NaN last). Diseases with no valid AUC in any batch
        are kept with NaN. An empty DataFrame is returned for an empty
        ``results_list``.
    """
    if not results_list:
        return pd.DataFrame()

    metadata_keys = ('batch_idx', 'analysis_type')

    # Union of disease names over every batch (first-seen order); using only
    # the first batch would drop diseases that first appear later.
    disease_names_list = list(dict.fromkeys(
        key
        for result in results_list
        for key in result
        if key not in metadata_keys
    ))

    auc_data = {disease: [] for disease in disease_names_list}

    for result in results_list:
        for disease in disease_names_list:
            metrics = result.get(disease)
            if not isinstance(metrics, dict):
                continue
            auc_val = metrics.get('auc')
            if auc_val is not None and not np.isnan(auc_val):
                auc_data[disease].append(auc_val)

    # Compute medians (NaN when a disease has no valid AUC at all)
    median_aucs = {
        disease: (np.median(aucs) if aucs else np.nan)
        for disease, aucs in auc_data.items()
    }

    # Building from a Series keeps the column name even when there are no diseases.
    df = pd.DataFrame({'median_auc': pd.Series(median_aucs, dtype=float)})
    df = df.sort_values('median_auc', ascending=False)

    return df

# Build one median-AUC table per result list (same order as the originals:
# joint 10yr, joint 30yr, fixed 10yr, fixed 30yr, joint static, fixed static).
print("Computing median AUCs...")
(
    joint_10yr_median_df,
    joint_30yr_median_df,
    fixed_10yr_median_df,
    fixed_30yr_median_df,
    joint_static_10yr_median_df,
    fixed_static_10yr_median_df,
) = map(
    compute_median_aucs,
    (
        joint_10yr_results,
        joint_30yr_results,
        fixed_10yr_results,
        fixed_30yr_results,
        joint_static_10yr_results,
        fixed_static_10yr_results,
    ),
)

# Each report is a ruled header followed by the table; one print call per report.
print("\n" + "=" * 80,
      "JOINT PHI - 10 YEAR PREDICTIONS - Median AUC",
      "=" * 80,
      joint_10yr_median_df,
      sep="\n")

print("\n" + "=" * 80,
      "JOINT PHI - 30 YEAR PREDICTIONS - Median AUC",
      "=" * 80,
      joint_30yr_median_df,
      sep="\n")

print("\n" + "=" * 80,
      "FIXED PHI - 10 YEAR PREDICTIONS - Median AUC",
      "=" * 80,
      fixed_10yr_median_df,
      sep="\n")

print("\n" + "=" * 80,
      "FIXED PHI - 30 YEAR PREDICTIONS - Median AUC",
      "=" * 80,
      fixed_30yr_median_df,
      sep="\n")

print("\n" + "=" * 80,
      "JOINT PHI - STATIC 10 YEAR PREDICTIONS (1-year score) - Median AUC",
      "=" * 80,
      joint_static_10yr_median_df,
      sep="\n")

print("\n" + "=" * 80,
      "FIXED PHI - STATIC 10 YEAR PREDICTIONS (1-year score) - Median AUC",
      "=" * 80,
      fixed_static_10yr_median_df,
      sep="\n")


In [None]:

# Write each median-AUC table to its own CSV (paths are relative to the
# current working directory).
for _median_frame, _csv_name in (
    (joint_10yr_median_df, 'joint_phi_10yr_median_auc.csv'),
    (joint_30yr_median_df, 'joint_phi_30yr_median_auc.csv'),
    (fixed_10yr_median_df, 'fixed_phi_10yr_median_auc.csv'),
    (fixed_30yr_median_df, 'fixed_phi_30yr_median_auc.csv'),
    (joint_static_10yr_median_df, 'joint_phi_static_10yr_median_auc.csv'),
    (fixed_static_10yr_median_df, 'fixed_phi_static_10yr_median_auc.csv'),
):
    _median_frame.to_csv(_csv_name)

print("\nMedian AUC DataFrames saved to CSV files.")


In [None]:
# Compute mean AUC across all batches
def compute_mean_aucs(results_list):
    """Compute the per-disease mean AUC across batches.

    Parameters
    ----------
    results_list : list of dict
        One dict per batch mapping disease name -> metric dict containing an
        'auc' entry. The metadata keys 'batch_idx' and 'analysis_type' are
        ignored; entries that are not dicts, and AUCs that are None or NaN,
        are skipped.

    Returns
    -------
    pandas.DataFrame
        Single column 'mean_auc', indexed by disease name and sorted in
        descending order (NaN last). Diseases with no valid AUC in any batch
        are kept with NaN. An empty DataFrame is returned for an empty
        ``results_list``.
    """
    if not results_list:
        return pd.DataFrame()

    metadata_keys = ('batch_idx', 'analysis_type')

    # Union of disease names over every batch (first-seen order); using only
    # the first batch would drop diseases that first appear later.
    disease_names_list = list(dict.fromkeys(
        key
        for result in results_list
        for key in result
        if key not in metadata_keys
    ))

    auc_data = {disease: [] for disease in disease_names_list}

    for result in results_list:
        for disease in disease_names_list:
            metrics = result.get(disease)
            if not isinstance(metrics, dict):
                continue
            auc_val = metrics.get('auc')
            if auc_val is not None and not np.isnan(auc_val):
                auc_data[disease].append(auc_val)

    # Compute means (NaN when a disease has no valid AUC at all)
    mean_aucs = {
        disease: (np.mean(aucs) if aucs else np.nan)
        for disease, aucs in auc_data.items()
    }

    # Building from a Series keeps the column name even when there are no diseases.
    df = pd.DataFrame({'mean_auc': pd.Series(mean_aucs, dtype=float)})
    df = df.sort_values('mean_auc', ascending=False)

    return df

# Compute mean AUC DataFrames
print("Computing mean AUCs...")
joint_10yr_mean_df = compute_mean_aucs(joint_10yr_results)
joint_30yr_mean_df = compute_mean_aucs(joint_30yr_results)
fixed_10yr_mean_df = compute_mean_aucs(fixed_10yr_results)
fixed_30yr_mean_df = compute_mean_aucs(fixed_30yr_results)
joint_static_10yr_mean_df = compute_mean_aucs(joint_static_10yr_results)
fixed_static_10yr_mean_df = compute_mean_aucs(fixed_static_10yr_results)

print("\n" + "="*80)
print("JOINT PHI - 10 YEAR PREDICTIONS - Mean AUC")
print("="*80)
print(joint_10yr_mean_df)

print("\n" + "="*80)
print("JOINT PHI - 30 YEAR PREDICTIONS - Mean AUC")
print("="*80)
print(joint_30yr_mean_df)

print("\n" + "="*80)
print("FIXED PHI - 10 YEAR PREDICTIONS - Mean AUC")
print("="*80)
print(fixed_10yr_mean_df)

print("\n" + "="*80)
print("FIXED PHI - 30 YEAR PREDICTIONS - Mean AUC")
print("="*80)
print(fixed_30yr_mean_df)

print("\n" + "="*80)
print("JOINT PHI - STATIC 10 YEAR PREDICTIONS (1-year score) - Mean AUC")
print("="*80)
print(joint_static_10yr_mean_df)

print("\n" + "="*80)
print("FIXED PHI - STATIC 10 YEAR PREDICTIONS (1-year score) - Mean AUC")
print("="*80)
print(fixed_static_10yr_mean_df)

# Save to CSV files.
# The *_mean_df frames computed above are what belongs in the *_mean_auc.csv
# files; writing the *_median_df frames here would silently store medians
# under "mean" file names.
joint_10yr_mean_df.to_csv('joint_phi_10yr_mean_auc.csv')
joint_30yr_mean_df.to_csv('joint_phi_30yr_mean_auc.csv')
fixed_10yr_mean_df.to_csv('fixed_phi_10yr_mean_auc.csv')
fixed_30yr_mean_df.to_csv('fixed_phi_30yr_mean_auc.csv')
joint_static_10yr_mean_df.to_csv('joint_phi_static_10yr_mean_auc.csv')
fixed_static_10yr_mean_df.to_csv('fixed_phi_static_10yr_mean_auc.csv')

print("\nMean AUC DataFrames saved to CSV files.")

## Washout analysis: repeat the evaluations with a true 1-year washout

In [None]:
from fig5utils import *
import pandas as pd
import numpy as np
import torch


In [None]:

# Load full pce_df
pce_df_full = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/pce_prevent_full.csv')
# NOTE(review): `essentials` and `model` are not defined in this cell; they
# must already exist in the kernel.
disease_names = essentials['disease_names']

# Full outcome / enrollment tensors: load once, up front, unless an earlier
# cell already left them in the kernel. Both branches below slice them per batch.
if 'Y_full' not in globals():
    Y_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/Y_tensor.pt')
if 'E_full' not in globals():
    E_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_enrollment_full.pt')

# Storage for results - Dynamic predictions
joint_10yr_results_washout = []
joint_30yr_results_washout = []
fixed_10yr_results_washout = []
fixed_30yr_results_washout = []

# Storage for results - Static predictions (1-year score for 10-year outcome)
joint_static_10yr_results_washout = []
fixed_static_10yr_results_washout = []

# Loop through batch indices 0-40, i.e. batch_0_10000 through batch_400000_410000.
# A batch whose checkpoint file does not exist is reported and skipped.
for batch_idx in range(41):
    start_idx = batch_idx * 10000
    end_idx = (batch_idx + 1) * 10000

    print(f"\n{'='*80}")
    print(f"Processing batch {batch_idx}: {start_idx} to {end_idx}")
    print(f"{'='*80}")

    # Get pce_df subset for this batch (explicit positional row slice)
    pce_df_subset = pce_df_full.iloc[start_idx:end_idx].copy().reset_index(drop=True)

    # === JOINT PHI CHECKPOINTS ===
    joint_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_prediction_jointphi_sex_pcs/enrollment_model_W0.0001_batch_{start_idx}_{end_idx}.pt'

    try:
        joint_ckpt = torch.load(joint_ckpt_path, weights_only=False)
        model.load_state_dict(joint_ckpt['model_state_dict'])

        # Use Y from checkpoint and update model.Y so forward() uses correct patients
        Y_batch = joint_ckpt['Y']
        # Use this batch's enrollment rows, exactly as the fixed-phi branch
        # does, instead of the batch-independent E_100k.
        # NOTE(review): assumes the checkpoint's Y rows are patients
        # start_idx:end_idx of the full cohort - confirm.
        E_batch = E_full[start_idx:end_idx]
        if Y_batch.shape[0] != E_batch.shape[0]:
            raise ValueError(
                f"Y has {Y_batch.shape[0]} patients but E slice has {E_batch.shape[0]}"
            )

        # detach().clone() makes the same independent copy torch.tensor() made,
        # without the copy-construct UserWarning raised for tensor inputs.
        model.Y = torch.as_tensor(Y_batch).detach().clone().to(dtype=torch.float32)
        model.N = Y_batch.shape[0]  # Update N to match new Y size

        # 10-year predictions
        print(f"\nJoint Phi - 10 year predictions...")
        joint_10yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic_withwashout(
            model, Y_batch, E_batch, disease_names, pce_df_subset, washout_years=1,
            n_bootstraps=100, follow_up_duration_years=10, patient_indices=None
        )
        joint_10yr['batch_idx'] = batch_idx
        joint_10yr_results_washout.append(joint_10yr)

        # 30-year predictions
        print(f"\nJoint Phi - 30 year predictions...")
        joint_30yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic_withwashout(
            model, Y_batch, E_batch, disease_names, pce_df_subset, washout_years=1,
            n_bootstraps=100, follow_up_duration_years=30, patient_indices=None
        )
        joint_30yr['batch_idx'] = batch_idx
        joint_30yr_results_washout.append(joint_30yr)

        # Static 10-year predictions (using 1-year score)
        print(f"\nJoint Phi - Static 10 year predictions (1-year score)...")
        joint_static_10yr = evaluate_major_diseases_wsex_with_bootstrap_withwashout(
            model=model,
            Y_100k=Y_batch,
            E_100k=E_batch,
            disease_names=disease_names,
            pce_df=pce_df_subset,
            washout_years=1,
            n_bootstraps=100,
            follow_up_duration_years=10,
        )
        joint_static_10yr['batch_idx'] = batch_idx
        joint_static_10yr_results_washout.append(joint_static_10yr)

    except FileNotFoundError:
        # No `continue`: a missing joint checkpoint must not prevent the
        # fixed-phi evaluation of the same batch below.
        print(f"Joint phi checkpoint not found: {joint_ckpt_path}")
    except Exception as e:
        print(f"Error processing joint phi checkpoint {batch_idx}: {e}")

    # === FIXED PHI CHECKPOINTS ===
    fixed_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_withpcs_fromclaudeoutput/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'

    try:
        fixed_ckpt = torch.load(fixed_ckpt_path, weights_only=False)
        model.load_state_dict(fixed_ckpt['model_state_dict'])

        # For fixed phi, Y is not saved in checkpoint, so extract from full Y tensor
        Y_batch = Y_full[start_idx:end_idx]
        E_batch = E_full[start_idx:end_idx]

        # Update model.Y and model.N so forward() uses correct patients
        model.Y = torch.as_tensor(Y_batch).detach().clone().to(dtype=torch.float32)
        model.N = Y_batch.shape[0]  # Update N to match new Y size

        # 10-year predictions
        print(f"\nFixed Phi - 10 year predictions...")
        fixed_10yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic_withwashout(
            model, Y_batch, E_batch, disease_names, pce_df_subset, washout_years=1,
            n_bootstraps=100, follow_up_duration_years=10, patient_indices=None
        )
        fixed_10yr['batch_idx'] = batch_idx
        fixed_10yr_results_washout.append(fixed_10yr)

        # 30-year predictions
        print(f"\nFixed Phi - 30 year predictions...")
        fixed_30yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic_withwashout(
            model, Y_batch, E_batch, disease_names, pce_df_subset, washout_years=1,
            n_bootstraps=100, follow_up_duration_years=30, patient_indices=None
        )
        fixed_30yr['batch_idx'] = batch_idx
        fixed_30yr_results_washout.append(fixed_30yr)

        # Static 10-year predictions (using 1-year score)
        print(f"\nFixed Phi - Static 10 year predictions (1-year score)...")
        fixed_static_10yr = evaluate_major_diseases_wsex_with_bootstrap_withwashout(
            model=model,
            Y_100k=Y_batch,
            E_100k=E_batch,
            disease_names=disease_names,
            pce_df=pce_df_subset,
            washout_years=1,
            n_bootstraps=100,
            follow_up_duration_years=10,
        )
        fixed_static_10yr['batch_idx'] = batch_idx
        fixed_static_10yr_results_washout.append(fixed_static_10yr)

    except FileNotFoundError:
        print(f"Fixed phi checkpoint not found: {fixed_ckpt_path}")
    except Exception as e:
        # Best-effort: report the failure and move on to the next batch.
        print(f"Error processing fixed phi checkpoint {batch_idx}: {e}")

print(f"\n{'='*80}")
print("Completed processing all checkpoints!")
print(f"{'='*80}")


In [None]:
# Compute and save median AUCs for washout results

def compute_median_aucs(results_list):
    """Compute the per-disease median AUC across batches.

    NOTE(review): a function with this name is also defined in an earlier
    cell; running this cell replaces it. This version names the index
    'disease' and prints warnings for empty input.

    Parameters
    ----------
    results_list : list of dict
        One dict per batch mapping disease name -> either a metric dict with
        an 'auc' entry or a bare numeric AUC. The metadata keys 'batch_idx'
        and 'analysis_type' are ignored; AUCs that are None or NaN are skipped.

    Returns
    -------
    pandas.DataFrame
        Single column 'median_auc', index named 'disease', sorted in
        descending order (NaN last). Diseases with no valid AUC in any batch
        are kept with NaN. An empty DataFrame is returned (after printing a
        warning) when there are no results or no disease keys.
    """
    if not results_list:
        print("Warning: Empty results list!")
        return pd.DataFrame()

    metadata_keys = ('batch_idx', 'analysis_type')

    # Union of disease names over every batch (first-seen order); using only
    # the first batch would drop diseases that first appear later.
    disease_names_list = list(dict.fromkeys(
        key
        for result in results_list
        for key in result
        if key not in metadata_keys
    ))

    if not disease_names_list:
        print("Warning: No disease names found in results!")
        return pd.DataFrame()

    def _valid_auc(entry):
        """Return the usable AUC held by one result entry, or None."""
        if isinstance(entry, dict):
            value = entry.get('auc')
        elif isinstance(entry, (int, float, np.integer, np.floating)):
            # Handle case where the entry is directly the AUC
            value = entry
        else:
            value = None
        if value is None or np.isnan(value):
            return None
        return value

    auc_data = {disease: [] for disease in disease_names_list}

    for result in results_list:
        for disease in disease_names_list:
            auc_value = _valid_auc(result.get(disease))
            if auc_value is not None:
                auc_data[disease].append(auc_value)

    # Compute medians
    median_aucs = {disease: np.median(auc_data[disease]) if auc_data[disease] else np.nan
                   for disease in disease_names_list}

    # Create DataFrame
    df = pd.DataFrame(list(median_aucs.items()), columns=['disease', 'median_auc'])
    df = df.set_index('disease').sort_values('median_auc', ascending=False)

    return df

# Debug aid: report the length of each washout result list, or NOT FOUND when
# the name is absent from the kernel namespace.
print("Checking available variables...")
print("\n".join(
    f"{_name}: {len(globals()[_name]) if _name in globals() else 'NOT FOUND'}"
    for _name in (
        'joint_10yr_results_washout',
        'joint_30yr_results_washout',
        'joint_static_10yr_results_washout',
        'fixed_10yr_results_washout',
        'fixed_30yr_results_washout',
        'fixed_static_10yr_results_washout',
    )
))

# Median-AUC table for every washout result list (evaluated in the listed order).
(
    joint_10yr_washout_median_df,
    joint_30yr_washout_median_df,
    joint_static_10yr_washout_median_df,
    fixed_10yr_washout_median_df,
    fixed_30yr_washout_median_df,
    fixed_static_10yr_washout_median_df,
) = map(
    compute_median_aucs,
    (
        joint_10yr_results_washout,
        joint_30yr_results_washout,
        joint_static_10yr_results_washout,
        fixed_10yr_results_washout,
        fixed_30yr_results_washout,
        fixed_static_10yr_results_washout,
    ),
)

# Each report is a ruled header followed by the table; one print call per report.
print("\n" + "=" * 80,
      "JOINT PHI - 10 YEAR PREDICTIONS (WITH WASHOUT) - Median AUC",
      "=" * 80,
      joint_10yr_washout_median_df,
      sep="\n")

print("\n" + "=" * 80,
      "JOINT PHI - 30 YEAR PREDICTIONS (WITH WASHOUT) - Median AUC",
      "=" * 80,
      joint_30yr_washout_median_df,
      sep="\n")

print("\n" + "=" * 80,
      "JOINT PHI - STATIC 10 YEAR PREDICTIONS (WITH WASHOUT) - Median AUC",
      "=" * 80,
      joint_static_10yr_washout_median_df,
      sep="\n")

print("\n" + "=" * 80,
      "FIXED PHI - 10 YEAR PREDICTIONS (WITH WASHOUT) - Median AUC",
      "=" * 80,
      fixed_10yr_washout_median_df,
      sep="\n")

print("\n" + "=" * 80,
      "FIXED PHI - 30 YEAR PREDICTIONS (WITH WASHOUT) - Median AUC",
      "=" * 80,
      fixed_30yr_washout_median_df,
      sep="\n")

print("\n" + "=" * 80,
      "FIXED PHI - STATIC 10 YEAR PREDICTIONS (WITH WASHOUT) - Median AUC",
      "=" * 80,
      fixed_static_10yr_washout_median_df,
      sep="\n")

# Write each table to its own CSV (paths relative to the working directory).
for _washout_frame, _csv_name in (
    (joint_10yr_washout_median_df, 'joint_phi_10yr_median_auc_washout.csv'),
    (joint_30yr_washout_median_df, 'joint_phi_30yr_median_auc_washout.csv'),
    (joint_static_10yr_washout_median_df, 'joint_phi_static_10yr_median_auc_washout.csv'),
    (fixed_10yr_washout_median_df, 'fixed_phi_10yr_median_auc_washout.csv'),
    (fixed_30yr_washout_median_df, 'fixed_phi_30yr_median_auc_washout.csv'),
    (fixed_static_10yr_washout_median_df, 'fixed_phi_static_10yr_median_auc_washout.csv'),
):
    _washout_frame.to_csv(_csv_name)

print("\nMedian AUC DataFrames (with washout) saved to CSV files.")

In [None]:
# Compare 10-year predictions: Joint vs Fixed, Washout vs Non-Washout

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns  # NOTE(review): not referenced anywhere in this cell

# Load the median-AUC tables written by the earlier cells (paths relative to
# the working directory). These names now hold DataFrames read from disk.
joint_10yr = pd.read_csv('joint_phi_10yr_median_auc.csv', index_col=0)
joint_10yr_washout = pd.read_csv('joint_phi_10yr_median_auc_washout.csv', index_col=0)
fixed_10yr = pd.read_csv('fixed_phi_10yr_median_auc.csv', index_col=0)
fixed_10yr_washout = pd.read_csv('fixed_phi_10yr_median_auc_washout.csv', index_col=0)

# Get common diseases across all datasets
all_diseases = set(joint_10yr.index) & set(joint_10yr_washout.index) & set(fixed_10yr.index) & set(fixed_10yr_washout.index)
diseases_sorted = sorted(list(all_diseases))

# Create comparison DataFrame
comparison_df = pd.DataFrame({
    'Joint_NoWashout': joint_10yr.loc[diseases_sorted, 'median_auc'],
    'Joint_Washout': joint_10yr_washout.loc[diseases_sorted, 'median_auc'],
    'Fixed_NoWashout': fixed_10yr.loc[diseases_sorted, 'median_auc'],
    'Fixed_Washout': fixed_10yr_washout.loc[diseases_sorted, 'median_auc']
}, index=diseases_sorted)

# Calculate differences (washout effect per model, and joint minus fixed per setting)
comparison_df['Joint_Diff'] = comparison_df['Joint_Washout'] - comparison_df['Joint_NoWashout']
comparison_df['Fixed_Diff'] = comparison_df['Fixed_Washout'] - comparison_df['Fixed_NoWashout']
comparison_df['Joint_vs_Fixed_NoWashout'] = comparison_df['Joint_NoWashout'] - comparison_df['Fixed_NoWashout']
comparison_df['Joint_vs_Fixed_Washout'] = comparison_df['Joint_Washout'] - comparison_df['Fixed_Washout']

print("="*80)
print("10-YEAR PREDICTIONS COMPARISON TABLE")
print("="*80)
print(comparison_df.round(3))

# Save comparison table
comparison_df.to_csv('10yr_comparison_table.csv')
print("\n✓ Comparison table saved to '10yr_comparison_table.csv'")

# Span of the y = x reference line: at least [0.4, 0.8], widened to cover
# every plotted AUC so the diagonal never stops short of the data.
# With no data the pandas min/max are NaN and the defaults are kept.
auc_columns = ['Joint_NoWashout', 'Joint_Washout', 'Fixed_NoWashout', 'Fixed_Washout']
diag_lo = min(0.4, comparison_df[auc_columns].min().min())
diag_hi = max(0.8, comparison_df[auc_columns].max().max())

# Visualization 1: Scatter plot - Washout vs No Washout (one panel per model)
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# color=None lets matplotlib use its default colour for the Joint panel.
for panel, prefix, point_color in zip(axes, ('Joint', 'Fixed'), (None, 'green')):
    panel.scatter(comparison_df[f'{prefix}_NoWashout'], comparison_df[f'{prefix}_Washout'],
                  alpha=0.7, s=100, color=point_color)
    panel.plot([diag_lo, diag_hi], [diag_lo, diag_hi], 'r--', linewidth=1, label='y=x')
    panel.set_xlabel('No Washout AUC', fontsize=12)
    panel.set_ylabel('Washout AUC', fontsize=12)
    panel.set_title(f'{prefix} Phi: Washout vs No Washout', fontsize=14, fontweight='bold')
    panel.grid(True, alpha=0.3)
    panel.legend()
    panel.set_aspect('equal', adjustable='box')

plt.tight_layout()
plt.savefig('washout_comparison_scatter.png', dpi=300, bbox_inches='tight')
plt.show()

# Visualization 2: Grouped bar chart for top diseases (ranked by Joint, no washout)
top_n = 15
top_diseases = comparison_df.nlargest(top_n, 'Joint_NoWashout').index

fig, ax = plt.subplots(figsize=(14, 8))
x = np.arange(len(top_diseases))
width = 0.2

# Four bars per disease, centred on the tick position.
bars1 = ax.bar(x - 1.5*width, comparison_df.loc[top_diseases, 'Joint_NoWashout'], width,
               label='Joint, No Washout', alpha=0.8)
bars2 = ax.bar(x - 0.5*width, comparison_df.loc[top_diseases, 'Joint_Washout'], width,
               label='Joint, Washout', alpha=0.8)
bars3 = ax.bar(x + 0.5*width, comparison_df.loc[top_diseases, 'Fixed_NoWashout'], width,
               label='Fixed, No Washout', alpha=0.8)
bars4 = ax.bar(x + 1.5*width, comparison_df.loc[top_diseases, 'Fixed_Washout'], width,
               label='Fixed, Washout', alpha=0.8)

ax.set_xlabel('Disease', fontsize=12)
ax.set_ylabel('Median AUC', fontsize=12)
ax.set_title(f'10-Year Predictions: Top {top_n} Diseases\n(Joint vs Fixed, Washout vs No Washout)',
             fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(top_diseases, rotation=45, ha='right')
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('10yr_comparison_barchart.png', dpi=300, bbox_inches='tight')
plt.show()

# Visualization 3: Difference plot (Washout - No Washout), sorted ascending
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

joint_diffs = comparison_df['Joint_Diff'].sort_values()
fixed_diffs = comparison_df['Fixed_Diff'].sort_values()

for panel, panel_diffs, panel_title, bar_color in zip(
    axes,
    (joint_diffs, fixed_diffs),
    ('Joint Phi: Effect of Washout', 'Fixed Phi: Effect of Washout'),
    (None, 'green'),
):
    panel.barh(range(len(panel_diffs)), panel_diffs.values, alpha=0.7, color=bar_color)
    panel.axvline(x=0, color='red', linestyle='--', linewidth=1)
    panel.set_yticks(range(len(panel_diffs)))
    panel.set_yticklabels(panel_diffs.index)
    panel.set_xlabel('AUC Difference (Washout - No Washout)', fontsize=12)
    panel.set_title(panel_title, fontsize=14, fontweight='bold')
    panel.grid(True, alpha=0.3, axis='x')

plt.tight_layout()
plt.savefig('washout_differences.png', dpi=300, bbox_inches='tight')
plt.show()

# Summary statistics
print("\n" + "="*80)
print("SUMMARY STATISTICS")
print("="*80)

for effect_label, effect_column in (
    ("Joint Phi - Washout Effect", 'Joint_Diff'),
    ("Fixed Phi - Washout Effect", 'Fixed_Diff'),
):
    effect = comparison_df[effect_column]
    print(f"\n{effect_label}:")
    print(f"  Mean difference: {effect.mean():.4f}")
    print(f"  Median difference: {effect.median():.4f}")
    print(f"  Std difference: {effect.std():.4f}")
    print(f"  Range: [{effect.min():.4f}, {effect.max():.4f}]")

for gap_label, gap_column in (
    ("Joint vs Fixed (No Washout)", 'Joint_vs_Fixed_NoWashout'),
    ("Joint vs Fixed (Washout)", 'Joint_vs_Fixed_Washout'),
):
    gap = comparison_df[gap_column]
    print(f"\n{gap_label}:")
    print(f"  Mean difference: {gap.mean():.4f}")
    print(f"  Median difference: {gap.median():.4f}")

print("\n✓ Visualizations saved!")


In [None]:
# Compare 10-year vs 30-year predictions: Joint vs Fixed Phi (with disease labels)

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Load all CSV files: one median-AUC table per model (joint / fixed phi),
# horizon (10 / 30 years) and washout setting. The first CSV column becomes the
# index, and the code below uses that index as the disease name.
joint_10yr = pd.read_csv('joint_phi_10yr_median_auc.csv', index_col=0)
joint_30yr = pd.read_csv('joint_phi_30yr_median_auc.csv', index_col=0)
fixed_10yr = pd.read_csv('fixed_phi_10yr_median_auc.csv', index_col=0)
fixed_30yr = pd.read_csv('fixed_phi_30yr_median_auc.csv', index_col=0)

joint_10yr_washout = pd.read_csv('joint_phi_10yr_median_auc_washout.csv', index_col=0)
joint_30yr_washout = pd.read_csv('joint_phi_30yr_median_auc_washout.csv', index_col=0)
fixed_10yr_washout = pd.read_csv('fixed_phi_10yr_median_auc_washout.csv', index_col=0)
fixed_30yr_washout = pd.read_csv('fixed_phi_30yr_median_auc_washout.csv', index_col=0)

# One entry per subplot, in row-major order of the 2x2 grid:
# (title, 10-year table, 30-year table, marker colour or None for the default).
scatter_panels = [
    ('Joint Phi (No Washout)', joint_10yr, joint_30yr, None),
    ('Joint Phi (With Washout)', joint_10yr_washout, joint_30yr_washout, None),
    ('Fixed Phi (No Washout)', fixed_10yr, fixed_30yr, 'green'),
    ('Fixed Phi (With Washout)', fixed_10yr_washout, fixed_30yr_washout, 'green'),
]

# Get common diseases. All eight tables are indexed with this list below, so
# the washout tables have to take part in the intersection too; otherwise a
# disease missing from only a washout CSV raises KeyError in `.loc`.
all_diseases = set.intersection(*[
    set(table.index)
    for _, table_10yr, table_30yr, _ in scatter_panels
    for table in (table_10yr, table_30yr)
])
diseases_sorted = sorted(all_diseases)

# Create figure with 4 subplots
fig, axes = plt.subplots(2, 2, figsize=(16, 14))

for ax, (panel_title, table_10yr, table_30yr, marker_color) in zip(axes.flat, scatter_panels):
    x_vals = table_10yr.loc[diseases_sorted, 'median_auc']
    y_vals = table_30yr.loc[diseases_sorted, 'median_auc']
    # Only pass `color` when a panel asks for one so the default colour is kept otherwise.
    color_kwargs = {'color': marker_color} if marker_color else {}
    ax.scatter(x_vals, y_vals, alpha=0.7, s=100, edgecolors='black', linewidth=0.5, **color_kwargs)
    for disease, x_val, y_val in zip(diseases_sorted, x_vals, y_vals):
        ax.annotate(disease, (x_val, y_val),
                    fontsize=8, alpha=0.7, ha='center', va='bottom')
    # Identity line: points above it have a higher 30-year than 10-year AUC.
    ax.plot([0.4, 0.8], [0.4, 0.8], 'r--', linewidth=1, label='y=x')
    ax.set_xlabel('10-Year AUC', fontsize=12)
    ax.set_ylabel('30-Year AUC', fontsize=12)
    ax.set_title(panel_title, fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend()
    ax.set_aspect('equal', adjustable='box')

plt.suptitle('10-Year vs 30-Year Predictions', fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig('10yr_vs_30yr_scatter_labeled.png', dpi=300, bbox_inches='tight')
plt.show()

# Per-panel Pearson correlation and mean (30yr - 10yr) difference.
panel_corrs = []
panel_mean_diffs = []
for _, table_10yr, table_30yr, _ in scatter_panels:
    auc_10yr = table_10yr.loc[diseases_sorted, 'median_auc']
    auc_30yr = table_30yr.loc[diseases_sorted, 'median_auc']
    auc_10yr_arr = auc_10yr.to_numpy(dtype=float)
    auc_30yr_arr = auc_30yr.to_numpy(dtype=float)
    # np.corrcoef returns NaN if any input is NaN, so correlate only the
    # diseases that have a finite AUC at both horizons.
    both_finite = np.isfinite(auc_10yr_arr) & np.isfinite(auc_30yr_arr)
    if both_finite.sum() >= 2:
        panel_corrs.append(np.corrcoef(auc_10yr_arr[both_finite], auc_30yr_arr[both_finite])[0, 1])
    else:
        panel_corrs.append(np.nan)
    # Series.mean() already skips NaN differences.
    panel_mean_diffs.append((auc_30yr - auc_10yr).mean())

# Keep the individual names available for later cells.
joint_diff_nowash, joint_diff_wash, fixed_diff_nowash, fixed_diff_wash = panel_mean_diffs

# Calculate correlations
print("="*80)
print("10-YEAR vs 30-YEAR CORRELATIONS")
print("="*80)
for panel_idx, (panel_title, _, _, _) in enumerate(scatter_panels):
    lead = "\n" if panel_idx == 0 else ""  # blank line only before the first row
    print(f"{lead}{panel_title}: {panel_corrs[panel_idx]:.3f}")

# Calculate mean differences
print("\n" + "="*80)
print("MEAN DIFFERENCES (30yr - 10yr)")
print("="*80)
for panel_idx, (panel_title, _, _, _) in enumerate(scatter_panels):
    lead = "\n" if panel_idx == 0 else ""
    print(f"{lead}{panel_title}: {panel_mean_diffs[panel_idx]:.4f}")

print("\n✓ Plot saved as '10yr_vs_30yr_scatter_labeled.png'")

In [None]:
# Aggregate bootstrap CIs across batches
# Option 1: Use existing ci_lower and ci_upper (median across batches)

def compute_aggregated_cis(results_list, name=""):
    """Aggregate per-batch AUCs and CI bounds into one summary row per disease.

    Parameters
    ----------
    results_list : list of dict
        One dict per batch. Every key except ``'batch_idx'`` is treated as a
        disease name. A disease entry is only used when it is itself a dict;
        its ``'auc'``, ``'ci_lower'`` and ``'ci_upper'`` values are collected
        independently, skipping any that are missing, ``None`` or NaN.
    name : str, optional
        Label for the result set. Kept for existing call sites; it does not
        affect the computation.

    Returns
    -------
    pandas.DataFrame
        Indexed by disease and sorted by ``median_auc`` (descending), with
        columns ``median_auc``, ``ci_lower_median``, ``ci_upper_median``,
        ``ci_lower_min``, ``ci_upper_max`` and ``n_batches`` (number of
        batches that supplied a usable AUC). A disease without any usable AUC
        gets NaN statistics and ``n_batches`` 0. An empty ``results_list``
        returns a completely empty DataFrame; a list whose batches contain no
        disease keys returns an empty DataFrame that still has the columns.

    Notes
    -----
    The interval columns summarise the per-batch bounds (median / extremes);
    they are not a pooled confidence interval.
    """
    if not results_list:
        return pd.DataFrame()

    summary_columns = ['median_auc', 'ci_lower_median', 'ci_upper_median',
                       'ci_lower_min', 'ci_upper_max', 'n_batches']

    # Union of disease names over every batch (first-seen order), so a disease
    # that is absent from the first batch is still aggregated.
    disease_names_list = list(dict.fromkeys(
        key for result in results_list for key in result if key != 'batch_idx'
    ))
    if not disease_names_list:
        # Nothing to aggregate; sorting a column-less frame would raise KeyError.
        return pd.DataFrame(columns=summary_columns)

    metric_keys = ('ci_lower', 'ci_upper', 'auc')
    collected = {disease: {key: [] for key in metric_keys}
                 for disease in disease_names_list}

    for result in results_list:
        for disease, metrics in result.items():
            if disease == 'batch_idx' or not isinstance(metrics, dict):
                continue
            for key in metric_keys:
                value = metrics.get(key)
                # None is checked first because np.isnan(None) raises TypeError.
                if value is not None and not np.isnan(value):
                    collected[disease][key].append(value)

    # Aggregate: median of bounds and median AUC
    aggregated = {}
    for disease in disease_names_list:
        aucs = collected[disease]['auc']
        # Bounds are only reported for diseases that also have a usable AUC.
        lowers = collected[disease]['ci_lower'] if aucs else []
        uppers = collected[disease]['ci_upper'] if aucs else []
        aggregated[disease] = {
            'median_auc': np.median(aucs) if aucs else np.nan,
            'ci_lower_median': np.median(lowers) if lowers else np.nan,
            'ci_upper_median': np.median(uppers) if uppers else np.nan,
            'ci_lower_min': np.min(lowers) if lowers else np.nan,
            'ci_upper_max': np.max(uppers) if uppers else np.nan,
            'n_batches': len(aucs),
        }

    df = pd.DataFrame(aggregated).T
    df = df.sort_values('median_auc', ascending=False)

    return df

# Compute aggregated CIs for all result sets
section_rule = "=" * 80
print(section_rule)
print("AGGREGATING BOOTSTRAP CIs ACROSS BATCHES")
print(section_rule)

# No-washout result sets.
joint_10yr_aggregated = compute_aggregated_cis(joint_10yr_results, "Joint 10yr")
joint_30yr_aggregated = compute_aggregated_cis(joint_30yr_results, "Joint 30yr")
fixed_10yr_aggregated = compute_aggregated_cis(fixed_10yr_results, "Fixed 10yr")
fixed_30yr_aggregated = compute_aggregated_cis(fixed_30yr_results, "Fixed 30yr")

# Washout result sets.
joint_10yr_washout_aggregated = compute_aggregated_cis(joint_10yr_results_washout, "Joint 10yr Washout")
joint_30yr_washout_aggregated = compute_aggregated_cis(joint_30yr_results_washout, "Joint 30yr Washout")
fixed_10yr_washout_aggregated = compute_aggregated_cis(fixed_10yr_results_washout, "Fixed 10yr Washout")
fixed_30yr_washout_aggregated = compute_aggregated_cis(fixed_30yr_results_washout, "Fixed 30yr Washout")

# Display: only the two joint-phi 10-year tables are printed, rounded to 3 decimals.
ci_display_columns = ['median_auc', 'ci_lower_median', 'ci_upper_median', 'n_batches']
for table_caption, aggregated_table in (
    ("\nJoint Phi - 10yr (No Washout):", joint_10yr_aggregated),
    ("\nJoint Phi - 10yr (With Washout):", joint_10yr_washout_aggregated),
):
    print(table_caption)
    print(aggregated_table[ci_display_columns].round(3))

# Save aggregated results: output file name -> table, written in this order.
aggregated_ci_outputs = {
    'joint_phi_10yr_aggregated_cis.csv': joint_10yr_aggregated,
    'joint_phi_30yr_aggregated_cis.csv': joint_30yr_aggregated,
    'fixed_phi_10yr_aggregated_cis.csv': fixed_10yr_aggregated,
    'fixed_phi_30yr_aggregated_cis.csv': fixed_30yr_aggregated,
    'joint_phi_10yr_washout_aggregated_cis.csv': joint_10yr_washout_aggregated,
    'joint_phi_30yr_washout_aggregated_cis.csv': joint_30yr_washout_aggregated,
    'fixed_phi_10yr_washout_aggregated_cis.csv': fixed_10yr_washout_aggregated,
    'fixed_phi_30yr_washout_aggregated_cis.csv': fixed_30yr_washout_aggregated,
}
for ci_csv_name, aggregated_table in aggregated_ci_outputs.items():
    aggregated_table.to_csv(ci_csv_name)

print("\n✓ Aggregated CI results saved to CSV files")

In [None]:
# Plot median AUCs with aggregated bootstrap CI intervals
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Load aggregated CI results (or compute them if not already done)
# Assuming we have the aggregated results...

# If aggregated results not computed yet, compute them
def compute_aggregated_cis(results_list, name=""):
    """Aggregate per-batch AUCs and CI bounds into one summary row per disease.

    This notebook defines ``compute_aggregated_cis`` in an earlier cell as
    well, there with a ``name`` argument and two extra columns. This
    definition accepts the same arguments and returns the same columns, so
    whichever cell ran last, both call styles keep working.

    Parameters
    ----------
    results_list : list of dict
        One dict per batch. Every key except ``'batch_idx'`` is treated as a
        disease name. A disease entry is only used when it is itself a dict;
        its ``'auc'``, ``'ci_lower'`` and ``'ci_upper'`` values are collected
        independently, skipping any that are missing, ``None`` or NaN.
    name : str, optional
        Label for the result set; it does not affect the computation.

    Returns
    -------
    pandas.DataFrame
        Indexed by disease and sorted by ``median_auc`` (descending), with
        columns ``median_auc``, ``ci_lower_median``, ``ci_upper_median``,
        ``ci_lower_min``, ``ci_upper_max`` and ``n_batches`` (number of
        batches that supplied a usable AUC). A disease without any usable AUC
        gets NaN statistics and ``n_batches`` 0. An empty ``results_list``
        returns a completely empty DataFrame; a list whose batches contain no
        disease keys returns an empty DataFrame that still has the columns.
    """
    if not results_list:
        return pd.DataFrame()

    summary_columns = ['median_auc', 'ci_lower_median', 'ci_upper_median',
                       'ci_lower_min', 'ci_upper_max', 'n_batches']

    # Union of disease names over every batch (first-seen order), so a disease
    # that is absent from the first batch is still aggregated.
    disease_names_list = list(dict.fromkeys(
        key for result in results_list for key in result if key != 'batch_idx'
    ))
    if not disease_names_list:
        # Nothing to aggregate; sorting a column-less frame would raise KeyError.
        return pd.DataFrame(columns=summary_columns)

    metric_keys = ('ci_lower', 'ci_upper', 'auc')
    collected = {disease: {key: [] for key in metric_keys}
                 for disease in disease_names_list}

    for result in results_list:
        for disease, metrics in result.items():
            if disease == 'batch_idx' or not isinstance(metrics, dict):
                continue
            for key in metric_keys:
                value = metrics.get(key)
                # None is checked first because np.isnan(None) raises TypeError.
                if value is not None and not np.isnan(value):
                    collected[disease][key].append(value)

    aggregated = {}
    for disease in disease_names_list:
        aucs = collected[disease]['auc']
        # Bounds are only reported for diseases that also have a usable AUC.
        lowers = collected[disease]['ci_lower'] if aucs else []
        uppers = collected[disease]['ci_upper'] if aucs else []
        aggregated[disease] = {
            'median_auc': np.median(aucs) if aucs else np.nan,
            'ci_lower_median': np.median(lowers) if lowers else np.nan,
            'ci_upper_median': np.median(uppers) if uppers else np.nan,
            'ci_lower_min': np.min(lowers) if lowers else np.nan,
            'ci_upper_max': np.max(uppers) if uppers else np.nan,
            'n_batches': len(aucs),
        }

    df = pd.DataFrame(aggregated).T
    df = df.sort_values('median_auc', ascending=False)
    return df

# Compute aggregated CIs
joint_10yr_agg = compute_aggregated_cis(joint_10yr_results)
joint_30yr_agg = compute_aggregated_cis(joint_30yr_results)
fixed_10yr_agg = compute_aggregated_cis(fixed_10yr_results)
fixed_30yr_agg = compute_aggregated_cis(fixed_30yr_results)

joint_10yr_washout_agg = compute_aggregated_cis(joint_10yr_results_washout)
joint_30yr_washout_agg = compute_aggregated_cis(joint_30yr_results_washout)
fixed_10yr_washout_agg = compute_aggregated_cis(fixed_10yr_results_washout)
fixed_30yr_washout_agg = compute_aggregated_cis(fixed_30yr_results_washout)

# One entry per subplot, in row-major order of the 2x2 grid:
# (title, 10-year table, 30-year table, marker colour or None for the default).
ci_panels = [
    ('Joint Phi (No Washout)', joint_10yr_agg, joint_30yr_agg, None),
    ('Joint Phi (With Washout)', joint_10yr_washout_agg, joint_30yr_washout_agg, None),
    ('Fixed Phi (No Washout)', fixed_10yr_agg, fixed_30yr_agg, 'green'),
    ('Fixed Phi (With Washout)', fixed_10yr_washout_agg, fixed_30yr_washout_agg, 'green'),
]

# Get common diseases. All eight tables are indexed with this list below, so
# the washout tables have to take part in the intersection too; otherwise a
# disease missing from only a washout result set raises KeyError in `.loc`.
all_diseases = set.intersection(*[
    set(agg_table.index)
    for _, agg_10yr, agg_30yr, _ in ci_panels
    for agg_table in (agg_10yr, agg_30yr)
])
diseases_sorted = sorted(all_diseases)

# Create comparison plot with error bars
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

for ax, (panel_title, agg_10yr, agg_30yr, marker_color) in zip(axes.flat, ci_panels):
    medians_10 = agg_10yr.loc[diseases_sorted, 'median_auc'].values
    medians_30 = agg_30yr.loc[diseases_sorted, 'median_auc'].values
    # errorbar takes asymmetric errors as [distance below, distance above] the point.
    xerr = [medians_10 - agg_10yr.loc[diseases_sorted, 'ci_lower_median'].values,
            agg_10yr.loc[diseases_sorted, 'ci_upper_median'].values - medians_10]
    yerr = [medians_30 - agg_30yr.loc[diseases_sorted, 'ci_lower_median'].values,
            agg_30yr.loc[diseases_sorted, 'ci_upper_median'].values - medians_30]
    # Only pass `color` when a panel asks for one so the default colour is kept otherwise.
    color_kwargs = {'color': marker_color} if marker_color else {}

    ax.errorbar(medians_10, medians_30,
                xerr=xerr, yerr=yerr,
                fmt='o', alpha=0.6, capsize=3, capthick=1, elinewidth=1, **color_kwargs)
    for disease, x_val, y_val in zip(diseases_sorted, medians_10, medians_30):
        ax.annotate(disease, (x_val, y_val),
                    fontsize=7, alpha=0.7, ha='center', va='bottom')
    # Identity line: points above it have a higher 30-year than 10-year AUC.
    ax.plot([0.4, 0.8], [0.4, 0.8], 'r--', linewidth=1, label='y=x')
    ax.set_xlabel('10-Year AUC (with 95% CI)', fontsize=11)
    ax.set_ylabel('30-Year AUC (with 95% CI)', fontsize=11)
    ax.set_title(panel_title, fontsize=13, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend()
    ax.set_aspect('equal', adjustable='box')

plt.suptitle('10-Year vs 30-Year Predictions (with Aggregated Bootstrap 95% CIs)', 
             fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig('10yr_vs_30yr_with_ci_errorbars.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Plot with CI error bars saved!")

In [None]:
# Plot with IQR across batches (between-batch variation) instead of aggregated bootstrap CIs

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def compute_median_and_iqr(results_list):
    """Summarise the between-batch spread of each disease's AUC.

    Parameters
    ----------
    results_list : list of dict
        One dict per batch. Every key except ``'batch_idx'`` is treated as a
        disease name. A disease entry is only used when it is itself a dict
        with an ``'auc'`` value that is not missing, ``None`` or NaN.

    Returns
    -------
    pandas.DataFrame
        Indexed by disease and sorted by ``median_auc`` (descending), with
        columns ``median_auc``, ``q25``, ``q75``, ``iqr`` (``q75 - q25``),
        ``min``, ``max`` and ``n_batches`` (number of batches that supplied a
        usable AUC). A disease without any usable AUC gets NaN statistics and
        ``n_batches`` 0. An empty ``results_list`` returns a completely empty
        DataFrame; a list whose batches contain no disease keys returns an
        empty DataFrame that still has the columns.
    """
    if not results_list:
        return pd.DataFrame()

    summary_columns = ['median_auc', 'q25', 'q75', 'iqr', 'min', 'max', 'n_batches']

    # Union of disease names over every batch (first-seen order), so a disease
    # that is absent from the first batch is still summarised.
    disease_names_list = list(dict.fromkeys(
        key for result in results_list for key in result if key != 'batch_idx'
    ))
    if not disease_names_list:
        # Nothing to summarise; sorting a column-less frame would raise KeyError.
        return pd.DataFrame(columns=summary_columns)

    auc_data = {disease: [] for disease in disease_names_list}
    for result in results_list:
        for disease, metrics in result.items():
            if disease == 'batch_idx' or not isinstance(metrics, dict):
                continue
            auc = metrics.get('auc')
            # None is checked first because np.isnan(None) raises TypeError.
            if auc is not None and not np.isnan(auc):
                auc_data[disease].append(auc)

    aggregated = {}
    for disease in disease_names_list:
        if auc_data[disease]:
            aucs = np.array(auc_data[disease])
            q25, q75 = np.percentile(aucs, 25), np.percentile(aucs, 75)
            aggregated[disease] = {
                'median_auc': np.median(aucs),
                'q25': q25,
                'q75': q75,
                'iqr': q75 - q25,
                'min': np.min(aucs),
                'max': np.max(aucs),
                'n_batches': len(aucs),
            }
        else:
            # Keep the row so the disease stays visible, but mark it as empty.
            aggregated[disease] = {
                'median_auc': np.nan,
                'q25': np.nan,
                'q75': np.nan,
                'iqr': np.nan,
                'min': np.nan,
                'max': np.nan,
                'n_batches': 0,
            }

    df = pd.DataFrame(aggregated).T
    df = df.sort_values('median_auc', ascending=False)
    return df

# Compute median and IQR for all result sets
joint_10yr_stats = compute_median_and_iqr(joint_10yr_results)
joint_30yr_stats = compute_median_and_iqr(joint_30yr_results)
fixed_10yr_stats = compute_median_and_iqr(fixed_10yr_results)
fixed_30yr_stats = compute_median_and_iqr(fixed_30yr_results)

joint_10yr_washout_stats = compute_median_and_iqr(joint_10yr_results_washout)
joint_30yr_washout_stats = compute_median_and_iqr(joint_30yr_results_washout)
fixed_10yr_washout_stats = compute_median_and_iqr(fixed_10yr_results_washout)
fixed_30yr_washout_stats = compute_median_and_iqr(fixed_30yr_results_washout)

# One entry per subplot, in row-major order of the 2x2 grid:
# (title, 10-year table, 30-year table, marker colour or None for the default).
iqr_panels = [
    ('Joint Phi (No Washout)', joint_10yr_stats, joint_30yr_stats, None),
    ('Joint Phi (With Washout)', joint_10yr_washout_stats, joint_30yr_washout_stats, None),
    ('Fixed Phi (No Washout)', fixed_10yr_stats, fixed_30yr_stats, 'green'),
    ('Fixed Phi (With Washout)', fixed_10yr_washout_stats, fixed_30yr_washout_stats, 'green'),
]

# Get common diseases. All eight tables are indexed with this list below, so
# the washout tables have to take part in the intersection too; otherwise a
# disease missing from only a washout result set raises KeyError in `.loc`.
all_diseases = set.intersection(*[
    set(stats_table.index)
    for _, stats_10yr, stats_30yr, _ in iqr_panels
    for stats_table in (stats_10yr, stats_30yr)
])
diseases_sorted = sorted(all_diseases)

# Create figure with 4 subplots
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

for ax, (panel_title, stats_10yr, stats_30yr, marker_color) in zip(axes.flat, iqr_panels):
    medians_10 = stats_10yr.loc[diseases_sorted, 'median_auc'].values
    medians_30 = stats_30yr.loc[diseases_sorted, 'median_auc'].values
    # errorbar takes asymmetric errors as [distance below, distance above] the
    # point; here the bars run from the 25th to the 75th percentile.
    xerr = [medians_10 - stats_10yr.loc[diseases_sorted, 'q25'].values,
            stats_10yr.loc[diseases_sorted, 'q75'].values - medians_10]
    yerr = [medians_30 - stats_30yr.loc[diseases_sorted, 'q25'].values,
            stats_30yr.loc[diseases_sorted, 'q75'].values - medians_30]
    # Only pass `color` when a panel asks for one so the default colour is kept otherwise.
    color_kwargs = {'color': marker_color} if marker_color else {}

    ax.errorbar(medians_10, medians_30,
                xerr=xerr, yerr=yerr,
                fmt='o', alpha=0.6, capsize=3, capthick=1, elinewidth=1, **color_kwargs)
    for disease, x_val, y_val in zip(diseases_sorted, medians_10, medians_30):
        ax.annotate(disease, (x_val, y_val),
                    fontsize=7, alpha=0.7, ha='center', va='bottom')
    # Identity line: points above it have a higher 30-year than 10-year AUC.
    ax.plot([0.4, 0.8], [0.4, 0.8], 'r--', linewidth=1, label='y=x')
    ax.set_xlabel('10-Year AUC (IQR across batches)', fontsize=11)
    ax.set_ylabel('30-Year AUC (IQR across batches)', fontsize=11)
    ax.set_title(panel_title, fontsize=13, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend()
    ax.set_aspect('equal', adjustable='box')

plt.suptitle('10-Year vs 30-Year Predictions\n(Error bars show IQR across batches)', 
             fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig('10yr_vs_30yr_with_iqr.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Plot with IQR error bars saved!")

In [None]:
# Display the per-disease table returned by compute_median_and_iqr(fixed_30yr_results).
fixed_30yr_stats

In [None]:
# Display the per-disease table returned by compute_median_and_iqr(joint_30yr_results).
joint_30yr_stats

In [None]:
# Display the per-disease table returned by compute_median_and_iqr(joint_10yr_results).
joint_10yr_stats

In [None]:
# Save IQR statistics to CSV files: (table, output file name) pairs, written
# in this order — no-washout tables first, then their washout counterparts.
iqr_csv_outputs = [
    (joint_10yr_stats, 'joint_phi_10yr_median_auc_iqr.csv'),
    (joint_30yr_stats, 'joint_phi_30yr_median_auc_iqr.csv'),
    (fixed_10yr_stats, 'fixed_phi_10yr_median_auc_iqr.csv'),
    (fixed_30yr_stats, 'fixed_phi_30yr_median_auc_iqr.csv'),
    (joint_10yr_washout_stats, 'joint_phi_10yr_median_auc_iqr_washout.csv'),
    (joint_30yr_washout_stats, 'joint_phi_30yr_median_auc_iqr_washout.csv'),
    (fixed_10yr_washout_stats, 'fixed_phi_10yr_median_auc_iqr_washout.csv'),
    (fixed_30yr_washout_stats, 'fixed_phi_30yr_median_auc_iqr_washout.csv'),
]
for iqr_stats_table, iqr_csv_name in iqr_csv_outputs:
    iqr_stats_table.to_csv(iqr_csv_name)

print("✓ IQR statistics saved to CSV files")

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Read the per-disease washout summaries for the fixed-phi and joint-phi runs.
washout_fixed = pd.read_csv('/Users/sarahurbut/aladynoulli2/claudefile/output/washout_summary_table.csv')
washout_joint = pd.read_csv('/Users/sarahurbut/aladynoulli2/claudefile/output/washout_summary_table_jointest.csv')

# Add derived columns to both tables before anything is plotted:
#   drop_0_to_Nyr = 0yr AUC minus the AUC after N years of washout
#   retention_Nyr = AUC after N years of washout divided by the 0yr AUC
for washout_table in (washout_fixed, washout_joint):
    auc_0yr = washout_table['0yr_AUC']
    auc_1yr = washout_table['1yr_AUC']
    auc_2yr = washout_table['2yr_AUC']
    washout_table['drop_0_to_1yr'] = auc_0yr - auc_1yr
    washout_table['drop_0_to_2yr'] = auc_0yr - auc_2yr
    washout_table['retention_1yr'] = auc_1yr / auc_0yr
    washout_table['retention_2yr'] = auc_2yr / auc_0yr

# Diseases present in both tables, in alphabetical order.
diseases_common = set(washout_fixed['Disease']).intersection(washout_joint['Disease'])
diseases_sorted = sorted(diseases_common)

# Create comparison plots: a 2x2 figure comparing fixed-phi and joint-phi AUCs
# across washout periods for every disease in `diseases_sorted`.

# Index both tables by disease once. `drop_duplicates` keeps the first row per
# disease, so every panel uses the same single row per disease and labels,
# x positions and bar heights cannot drift out of alignment if a table
# happens to list a disease more than once.
fixed_by_disease = washout_fixed.drop_duplicates(subset='Disease').set_index('Disease')
joint_by_disease = washout_joint.drop_duplicates(subset='Disease').set_index('Disease')
auc_columns = ['0yr_AUC', '1yr_AUC', '2yr_AUC']
washout_years = [0, 1, 2]

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1. AUC by washout period (line plot for ALL diseases)
for disease in diseases_sorted:
    axes[0, 0].plot(washout_years,
                    fixed_by_disease.loc[disease, auc_columns].to_numpy(dtype=float),
                    marker='o', alpha=0.4, color='blue', linewidth=1, markersize=4)
    axes[0, 0].plot(washout_years,
                    joint_by_disease.loc[disease, auc_columns].to_numpy(dtype=float),
                    marker='s', alpha=0.4, color='red', linewidth=1, markersize=4)

# Add mean lines.
# NOTE(review): the means are taken over every row of each table, not only the
# diseases common to both tables -- confirm that this is intended.
axes[0, 0].plot(washout_years,
                [washout_fixed[col].mean() for col in auc_columns],
                marker='o', label='Fixed Phi (mean)', linewidth=3, markersize=10, color='darkblue')
axes[0, 0].plot(washout_years,
                [washout_joint[col].mean() for col in auc_columns],
                marker='s', label='Joint Phi (mean)', linewidth=3, markersize=10, color='darkred')

axes[0, 0].set_xlabel('Washout Period (years)', fontsize=12)
axes[0, 0].set_ylabel('AUC', fontsize=12)
axes[0, 0].set_title('AUC by Washout Period (All diseases + mean)', fontsize=13, fontweight='bold')
axes[0, 0].set_xticks(washout_years)
axes[0, 0].set_xticklabels(['0yr (immediate)', '1yr washout', '2yr washout'])
axes[0, 0].grid(True, alpha=0.3)
axes[0, 0].legend()

# 2. AUC drop comparison (0yr to 1yr) - ALL diseases
fixed_drops = fixed_by_disease.loc[diseases_sorted, 'drop_0_to_1yr']
joint_drops = joint_by_disease.loc[diseases_sorted, 'drop_0_to_1yr']

axes[0, 1].scatter(fixed_drops, joint_drops, alpha=0.6, s=100)
axes[0, 1].plot([0, 0.5], [0, 0.5], 'r--', linewidth=1, label='y=x')
for disease, fixed_drop, joint_drop in zip(diseases_sorted, fixed_drops, joint_drops):
    axes[0, 1].annotate(disease, (fixed_drop, joint_drop),
                       fontsize=7, alpha=0.7, ha='center', va='bottom')
axes[0, 1].set_xlabel('Fixed Phi: AUC Drop (0yr → 1yr)', fontsize=11)
axes[0, 1].set_ylabel('Joint Phi: AUC Drop (0yr → 1yr)', fontsize=11)
axes[0, 1].set_title('AUC Drop Comparison (1-year washout)', fontsize=13, fontweight='bold')
axes[0, 1].grid(True, alpha=0.3)
axes[0, 1].legend()

# 3. Retention ratio (1yr washout) - ALL diseases
fixed_retention = fixed_by_disease.loc[diseases_sorted, 'retention_1yr']
joint_retention = joint_by_disease.loc[diseases_sorted, 'retention_1yr']

axes[1, 0].scatter(fixed_retention, joint_retention, alpha=0.6, s=100, color='green')
axes[1, 0].plot([0.4, 1.0], [0.4, 1.0], 'r--', linewidth=1, label='y=x')
axes[1, 0].axhline(0.9, color='gray', linestyle=':', alpha=0.5, label='90% retention')
axes[1, 0].axvline(0.9, color='gray', linestyle=':', alpha=0.5)
for disease, fixed_ratio, joint_ratio in zip(diseases_sorted, fixed_retention, joint_retention):
    axes[1, 0].annotate(disease, (fixed_ratio, joint_ratio),
                       fontsize=7, alpha=0.7, ha='center', va='bottom')
axes[1, 0].set_xlabel('Fixed Phi: Retention Ratio (1yr/0yr)', fontsize=11)
axes[1, 0].set_ylabel('Joint Phi: Retention Ratio (1yr/0yr)', fontsize=11)
axes[1, 0].set_title('AUC Retention with 1-year Washout', fontsize=13, fontweight='bold')
axes[1, 0].grid(True, alpha=0.3)
axes[1, 0].legend()
# Fixed axis window; points with a retention ratio outside [0.4, 1.0] are not shown.
axes[1, 0].set_xlim([0.4, 1.0])
axes[1, 0].set_ylim([0.4, 1.0])

# 4. Bar chart: All diseases by AUC drop (1yr washout) - sorted
x_pos = np.arange(len(diseases_sorted))
width = 0.35

# Order diseases from largest to smallest fixed-phi drop and apply the same
# order to the joint-phi drops so paired bars describe the same disease.
sort_idx = np.argsort(fixed_drops.values)[::-1]
diseases_sorted_by_drop = [diseases_sorted[i] for i in sort_idx]
fixed_drops_sorted = fixed_drops.iloc[sort_idx]
joint_drops_sorted = joint_drops.iloc[sort_idx]

axes[1, 1].bar(x_pos - width/2, fixed_drops_sorted.values,
               width, label='Fixed Phi', alpha=0.7)
axes[1, 1].bar(x_pos + width/2, joint_drops_sorted.values,
               width, label='Joint Phi', alpha=0.7)
axes[1, 1].set_xlabel('Disease (sorted by Fixed Phi drop)', fontsize=11)
axes[1, 1].set_ylabel('AUC Drop (0yr → 1yr)', fontsize=11)
axes[1, 1].set_title('All Diseases: AUC Drop with 1-year Washout', fontsize=13, fontweight='bold')
axes[1, 1].set_xticks(x_pos)
axes[1, 1].set_xticklabels(diseases_sorted_by_drop, rotation=45, ha='right', fontsize=8)
axes[1, 1].grid(True, alpha=0.3, axis='y')
axes[1, 1].legend()

plt.suptitle('Washout Analysis: Fixed Phi vs Joint Phi\n(1-year predictions with shifting windows)',
             fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig('washout_comparison_fixed_vs_joint.png', dpi=300, bbox_inches='tight')
plt.show()

# Summary statistics: print the same four metrics for each run, fixed phi first.
washout_runs = (('Fixed Phi', washout_fixed), ('Joint Phi', washout_joint))
rule = "=" * 80

print("\n" + rule)
print("WASHOUT ANALYSIS SUMMARY")
print(rule)

print("\nMean AUC by washout period:")
for run_label, run_table in washout_runs:
    mean_0yr = run_table['0yr_AUC'].mean()
    mean_1yr = run_table['1yr_AUC'].mean()
    mean_2yr = run_table['2yr_AUC'].mean()
    print(f"  {run_label}: 0yr={mean_0yr:.3f}, 1yr={mean_1yr:.3f}, 2yr={mean_2yr:.3f}")

print("\nMean AUC drop (0yr → 1yr):")
for run_label, run_table in washout_runs:
    one_year_drop = run_table['drop_0_to_1yr']
    print(f"  {run_label}: {one_year_drop.mean():.3f} ± {one_year_drop.std():.3f}")

print("\nMean retention ratio (1yr/0yr):")
for run_label, run_table in washout_runs:
    one_year_retention = run_table['retention_1yr']
    print(f"  {run_label}: {one_year_retention.mean():.3f} ± {one_year_retention.std():.3f}")

# The threshold is applied to the absolute AUC difference (drop > 0.2).
print("\nDiseases with >20% AUC drop (1yr washout):")
for run_label, run_table in washout_runs:
    n_large_drops = int((run_table['drop_0_to_1yr'] > 0.2).sum())
    print(f"  {run_label}: {n_large_drops} diseases")
print(rule)

In [None]:
# Show the fixed-phi washout table as this cell's output.
washout_fixed

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Load washout tables (1-year predictions with shifting windows)
washout_fixed = pd.read_csv('/Users/sarahurbut/aladynoulli2/claudefile/output/washout_summary_table.csv')
washout_joint = pd.read_csv('/Users/sarahurbut/aladynoulli2/claudefile/output/washout_summary_table_jointest.csv')

# Calculate drops for washout tables
washout_fixed['drop_0_to_1yr'] = washout_fixed['0yr_AUC'] - washout_fixed['1yr_AUC']
washout_joint['drop_0_to_1yr'] = washout_joint['0yr_AUC'] - washout_joint['1yr_AUC']
washout_fixed['retention_1yr'] = washout_fixed['1yr_AUC'] / washout_fixed['0yr_AUC']
washout_joint['retention_1yr'] = washout_joint['1yr_AUC'] / washout_joint['0yr_AUC']

# Load 10/30-year results (from lifetime.ipynb IQR stats).
# These tables are saved with the disease names as the frame index, which
# to_csv writes as the first column. `index_col=0` restores that index; without
# it the frames get a 0..n-1 integer index and no disease name can ever match.
joint_10yr = pd.read_csv('joint_phi_10yr_median_auc_iqr.csv', index_col=0)
joint_30yr = pd.read_csv('joint_phi_30yr_median_auc_iqr.csv', index_col=0)
fixed_10yr = pd.read_csv('fixed_phi_10yr_median_auc_iqr.csv', index_col=0)
fixed_30yr = pd.read_csv('fixed_phi_30yr_median_auc_iqr.csv', index_col=0)

joint_10yr_washout = pd.read_csv('joint_phi_10yr_median_auc_iqr_washout.csv', index_col=0)
joint_30yr_washout = pd.read_csv('joint_phi_30yr_median_auc_iqr_washout.csv', index_col=0)
fixed_10yr_washout = pd.read_csv('fixed_phi_10yr_median_auc_iqr_washout.csv', index_col=0)
fixed_30yr_washout = pd.read_csv('fixed_phi_30yr_median_auc_iqr_washout.csv', index_col=0)

# Calculate drops for 10/30-year predictions.
# Series arithmetic aligns on the index, so each disease's washout AUC is
# subtracted from (or divided by) that same disease's no-washout AUC.
joint_10yr['drop_washout'] = joint_10yr['median_auc'] - joint_10yr_washout['median_auc']
joint_30yr['drop_washout'] = joint_30yr['median_auc'] - joint_30yr_washout['median_auc']
fixed_10yr['drop_washout'] = fixed_10yr['median_auc'] - fixed_10yr_washout['median_auc']
fixed_30yr['drop_washout'] = fixed_30yr['median_auc'] - fixed_30yr_washout['median_auc']

joint_10yr['retention_washout'] = joint_10yr_washout['median_auc'] / joint_10yr['median_auc']
joint_30yr['retention_washout'] = joint_30yr_washout['median_auc'] / joint_30yr['median_auc']
fixed_10yr['retention_washout'] = fixed_10yr_washout['median_auc'] / fixed_10yr['median_auc']
fixed_30yr['retention_washout'] = fixed_30yr_washout['median_auc'] / fixed_30yr['median_auc']

# Get common diseases across all analyses
diseases_1yr = set(washout_fixed['Disease']) & set(washout_joint['Disease'])
diseases_10yr = set(joint_10yr.index) & set(fixed_10yr.index)
diseases_common = diseases_1yr & diseases_10yr
diseases_sorted = sorted(list(diseases_common))

print(f"Found {len(diseases_common)} common diseases")
print(f"Diseases: {diseases_sorted}")

# Build one record per disease holding the washout-induced AUC drop at each
# prediction horizon, for both the fixed-phi and joint-phi runs.
comparison_data = []
for disease in diseases_sorted:
    try:
        # 1-year predictions (shifting window): every matching row's drop.
        fixed_matches = washout_fixed.loc[washout_fixed['Disease'] == disease, 'drop_0_to_1yr']
        joint_matches = washout_joint.loc[washout_joint['Disease'] == disease, 'drop_0_to_1yr']
        if fixed_matches.empty or joint_matches.empty:
            continue

        # 10/30-year predictions (true washout): need the disease in both 10-year tables.
        if not (disease in joint_10yr.index and disease in fixed_10yr.index):
            continue

        # A missing label raises KeyError here and is reported by the handler below.
        long_horizon_drops = {
            column: table.loc[disease, 'drop_washout']
            for column, table in (
                ('Joint_10yr_drop', joint_10yr),
                ('Joint_30yr_drop', joint_30yr),
                ('Fixed_10yr_drop', fixed_10yr),
                ('Fixed_30yr_drop', fixed_30yr),
            )
        }

        comparison_data.append({
            'Disease': disease,
            'Fixed_1yr_drop': fixed_matches.iloc[0],  # first matching row
            'Fixed_10yr_drop': long_horizon_drops['Fixed_10yr_drop'],
            'Fixed_30yr_drop': long_horizon_drops['Fixed_30yr_drop'],
            'Joint_1yr_drop': joint_matches.iloc[0],  # first matching row
            'Joint_10yr_drop': long_horizon_drops['Joint_10yr_drop'],
            'Joint_30yr_drop': long_horizon_drops['Joint_30yr_drop'],
        })
    except Exception as e:
        # Best effort: report the failing disease and move on to the next one.
        print(f"Error processing {disease}: {e}")

# Plot and summarise the comparison records, or print diagnostics if none were built.
if comparison_data:
    comp_df = pd.DataFrame(comparison_data)
    print(f"\nComparison DataFrame shape: {comp_df.shape}")
    print(f"Columns: {comp_df.columns.tolist()}")

    # Create comparison plots
    fig, axes = plt.subplots(2, 2, figsize=(18, 12))
    x_pos = np.arange(len(comp_df))
    width = 0.25

    # Top row: grouped bars of the AUC drop per prediction horizon, one panel
    # per run. The fixed-phi panel uses matplotlib's default colours; every
    # joint-phi bar is red.
    horizon_labels = ('1yr pred (shift)', '10yr pred (washout)', '30yr pred (washout)')
    bar_panels = (
        (axes[0, 0], 'Fixed Phi: Washout Effect by Prediction Horizon',
         ('Fixed_1yr_drop', 'Fixed_10yr_drop', 'Fixed_30yr_drop'), {}),
        (axes[0, 1], 'Joint Phi: Washout Effect by Prediction Horizon',
         ('Joint_1yr_drop', 'Joint_10yr_drop', 'Joint_30yr_drop'), {'color': 'red'}),
    )
    for ax, panel_title, drop_columns, bar_style in bar_panels:
        for offset, drop_column, horizon_label in zip((-width, 0, width), drop_columns, horizon_labels):
            ax.bar(x_pos + offset, comp_df[drop_column], width,
                   label=horizon_label, alpha=0.7, **bar_style)
        ax.set_xlabel('Disease', fontsize=11)
        ax.set_ylabel('AUC Drop with Washout', fontsize=11)
        ax.set_title(panel_title, fontsize=13, fontweight='bold')
        ax.set_xticks(x_pos)
        ax.set_xticklabels(comp_df['Disease'], rotation=45, ha='right', fontsize=8)
        ax.grid(True, alpha=0.3, axis='y')
        ax.legend()
        ax.axhline(0, color='black', linewidth=0.5)

    # Bottom row: 1yr drop against 10yr drop per disease, with a y=x reference line.
    scatter_panels = (
        (axes[1, 0], 'Fixed Phi: Washout Effect Comparison',
         'Fixed_1yr_drop', 'Fixed_10yr_drop', {}),
        (axes[1, 1], 'Joint Phi: Washout Effect Comparison',
         'Joint_1yr_drop', 'Joint_10yr_drop', {'color': 'red'}),
    )
    for ax, panel_title, x_column, y_column, point_style in scatter_panels:
        ax.scatter(comp_df[x_column], comp_df[y_column], alpha=0.6, s=100, **point_style)
        ax.plot([-0.1, 0.5], [-0.1, 0.5], 'r--', linewidth=1, label='y=x')
        for disease, x_value, y_value in zip(comp_df['Disease'], comp_df[x_column], comp_df[y_column]):
            ax.annotate(disease, (x_value, y_value),
                        fontsize=7, alpha=0.7, ha='center', va='bottom')
        ax.set_xlabel('1yr Prediction Drop (shifting window)', fontsize=11)
        ax.set_ylabel('10yr Prediction Drop (true washout)', fontsize=11)
        ax.set_title(panel_title, fontsize=13, fontweight='bold')
        ax.grid(True, alpha=0.3)
        ax.legend()

    plt.suptitle('Washout Effect: 1-year (shifting) vs 10/30-year (true washout) Predictions',
                 fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.savefig('washout_comparison_across_horizons.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Summary statistics
    rule = "=" * 80
    print("\n" + rule)
    print("WASHOUT EFFECT COMPARISON ACROSS PREDICTION HORIZONS")
    print(rule)

    print("\nMean AUC drop by prediction horizon:")
    summary_rows = (
        ('Fixed Phi - 1yr (shift)', 'Fixed_1yr_drop'),
        ('Fixed Phi - 10yr (washout)', 'Fixed_10yr_drop'),
        ('Fixed Phi - 30yr (washout)', 'Fixed_30yr_drop'),
        ('Joint Phi - 1yr (shift)', 'Joint_1yr_drop'),
        ('Joint Phi - 10yr (washout)', 'Joint_10yr_drop'),
        ('Joint Phi - 30yr (washout)', 'Joint_30yr_drop'),
    )
    for row_label, drop_column in summary_rows:
        drops = comp_df[drop_column]
        print(f"  {row_label}: {drops.mean():.3f} ± {drops.std():.3f}")

    print("\n" + rule)
    print("Summary table saved to 'washout_comparison_table.csv'")
    comp_df.to_csv('washout_comparison_table.csv', index=False)
    print(rule)
else:
    # Nothing to plot: show a few names from each side to help spot a mismatch.
    print("ERROR: No common diseases found. Check disease name matching.")
    print(f"1yr diseases: {sorted(diseases_1yr)[:5]}")
    print(f"10yr diseases: {sorted(diseases_10yr)[:5]}")

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def _with_disease_index(stats_frame):
    """Return ``stats_frame`` indexed by its leading label column when one is found.

    An 'Unnamed: 0' column (what read_csv produces for an unnamed saved index)
    becomes the index. Otherwise, if the frame still has an unnamed int64 index
    and its first column has object dtype, that first column becomes the index.
    In every other case the frame is returned unchanged.
    """
    if 'Unnamed: 0' in stats_frame.columns:
        return stats_frame.set_index('Unnamed: 0')
    if stats_frame.index.name is None and stats_frame.index.dtype == 'int64':
        leading_column = stats_frame.columns[0]
        if stats_frame[leading_column].dtype == 'object':  # Likely disease names
            return stats_frame.set_index(leading_column)
    return stats_frame


# Load washout tables (1-year predictions with shifting windows)
washout_fixed = pd.read_csv('/Users/sarahurbut/aladynoulli2/claudefile/output/washout_summary_table.csv')
washout_joint = pd.read_csv('/Users/sarahurbut/aladynoulli2/claudefile/output/washout_summary_table_jointest.csv')

# Derived 1-year washout metrics: absolute AUC drop and retained fraction.
for washout_table in (washout_fixed, washout_joint):
    washout_table['drop_0_to_1yr'] = washout_table['0yr_AUC'] - washout_table['1yr_AUC']
    washout_table['retention_1yr'] = washout_table['1yr_AUC'] / washout_table['0yr_AUC']

# Load 10/30-year results (from lifetime.ipynb IQR stats), re-indexed by disease name.
joint_10yr = _with_disease_index(pd.read_csv('joint_phi_10yr_median_auc_iqr.csv'))
joint_30yr = _with_disease_index(pd.read_csv('joint_phi_30yr_median_auc_iqr.csv'))
fixed_10yr = _with_disease_index(pd.read_csv('fixed_phi_10yr_median_auc_iqr.csv'))
fixed_30yr = _with_disease_index(pd.read_csv('fixed_phi_30yr_median_auc_iqr.csv'))

joint_10yr_washout = _with_disease_index(pd.read_csv('joint_phi_10yr_median_auc_iqr_washout.csv'))
joint_30yr_washout = _with_disease_index(pd.read_csv('joint_phi_30yr_median_auc_iqr_washout.csv'))
fixed_10yr_washout = _with_disease_index(pd.read_csv('fixed_phi_10yr_median_auc_iqr_washout.csv'))
fixed_30yr_washout = _with_disease_index(pd.read_csv('fixed_phi_30yr_median_auc_iqr_washout.csv'))

# Washout effect on the 10/30-year predictions. Each no-washout table is paired
# with its washout counterpart; the Series arithmetic aligns rows on the index.
horizon_pairs = (
    (joint_10yr, joint_10yr_washout),
    (joint_30yr, joint_30yr_washout),
    (fixed_10yr, fixed_10yr_washout),
    (fixed_30yr, fixed_30yr_washout),
)
for no_washout_table, with_washout_table in horizon_pairs:
    no_washout_table['drop_washout'] = no_washout_table['median_auc'] - with_washout_table['median_auc']
    no_washout_table['retention_washout'] = with_washout_table['median_auc'] / no_washout_table['median_auc']

# Diseases available in every input; the 10/30-year side is matched on the frame index.
diseases_1yr = set(washout_fixed['Disease']).intersection(washout_joint['Disease'])
diseases_10yr = set(joint_10yr.index).intersection(fixed_10yr.index)
diseases_common = diseases_1yr & diseases_10yr
diseases_sorted = sorted(diseases_common)

print(f"Found {len(diseases_common)} common diseases")
print(f"Diseases: {diseases_sorted[:10]}...")  # Show first 10

# Build one record per disease holding the washout-induced AUC drop at each
# prediction horizon, for both the fixed-phi and joint-phi runs.
comparison_data = []
for disease in diseases_sorted:
    try:
        # 1-year predictions (shifting window): every matching row's drop.
        fixed_matches = washout_fixed.loc[washout_fixed['Disease'] == disease, 'drop_0_to_1yr']
        joint_matches = washout_joint.loc[washout_joint['Disease'] == disease, 'drop_0_to_1yr']
        if fixed_matches.empty or joint_matches.empty:
            continue

        # 10/30-year predictions (true washout), looked up by index label;
        # the disease must be present in both 10-year tables.
        if not (disease in joint_10yr.index and disease in fixed_10yr.index):
            continue

        # A missing label raises KeyError here and is reported by the handler below.
        long_horizon_drops = {
            column: table.loc[disease, 'drop_washout']
            for column, table in (
                ('Joint_10yr_drop', joint_10yr),
                ('Joint_30yr_drop', joint_30yr),
                ('Fixed_10yr_drop', fixed_10yr),
                ('Fixed_30yr_drop', fixed_30yr),
            )
        }

        comparison_data.append({
            'Disease': disease,
            'Fixed_1yr_drop': fixed_matches.iloc[0],  # first matching row
            'Fixed_10yr_drop': long_horizon_drops['Fixed_10yr_drop'],
            'Fixed_30yr_drop': long_horizon_drops['Fixed_30yr_drop'],
            'Joint_1yr_drop': joint_matches.iloc[0],  # first matching row
            'Joint_10yr_drop': long_horizon_drops['Joint_10yr_drop'],
            'Joint_30yr_drop': long_horizon_drops['Joint_30yr_drop'],
        })
    except Exception as e:
        # Best effort: report the failing disease and move on to the next one.
        print(f"Error processing {disease}: {e}")

# Plot and summarise the comparison records, or print diagnostics if none were built.
if comparison_data:
    comp_df = pd.DataFrame(comparison_data)
    print(f"\nComparison DataFrame shape: {comp_df.shape}")

    # Create comparison plots
    fig, axes = plt.subplots(2, 2, figsize=(18, 12))
    x_pos = np.arange(len(comp_df))
    width = 0.25

    # Top row: grouped bars of the AUC drop per prediction horizon, one panel
    # per run. The fixed-phi panel uses matplotlib's default colours; every
    # joint-phi bar is red.
    horizon_labels = ('1yr pred (shift)', '10yr pred (washout)', '30yr pred (washout)')
    bar_panels = (
        (axes[0, 0], 'Fixed Phi: Washout Effect by Prediction Horizon',
         ('Fixed_1yr_drop', 'Fixed_10yr_drop', 'Fixed_30yr_drop'), {}),
        (axes[0, 1], 'Joint Phi: Washout Effect by Prediction Horizon',
         ('Joint_1yr_drop', 'Joint_10yr_drop', 'Joint_30yr_drop'), {'color': 'red'}),
    )
    for ax, panel_title, drop_columns, bar_style in bar_panels:
        for offset, drop_column, horizon_label in zip((-width, 0, width), drop_columns, horizon_labels):
            ax.bar(x_pos + offset, comp_df[drop_column], width,
                   label=horizon_label, alpha=0.7, **bar_style)
        ax.set_xlabel('Disease', fontsize=11)
        ax.set_ylabel('AUC Drop with Washout', fontsize=11)
        ax.set_title(panel_title, fontsize=13, fontweight='bold')
        ax.set_xticks(x_pos)
        ax.set_xticklabels(comp_df['Disease'], rotation=45, ha='right', fontsize=8)
        ax.grid(True, alpha=0.3, axis='y')
        ax.legend()
        ax.axhline(0, color='black', linewidth=0.5)

    # Bottom row: 1yr drop against 10yr drop per disease, with a y=x reference line.
    scatter_panels = (
        (axes[1, 0], 'Fixed Phi: Washout Effect Comparison',
         'Fixed_1yr_drop', 'Fixed_10yr_drop', {}),
        (axes[1, 1], 'Joint Phi: Washout Effect Comparison',
         'Joint_1yr_drop', 'Joint_10yr_drop', {'color': 'red'}),
    )
    for ax, panel_title, x_column, y_column, point_style in scatter_panels:
        ax.scatter(comp_df[x_column], comp_df[y_column], alpha=0.6, s=100, **point_style)
        ax.plot([-0.1, 0.5], [-0.1, 0.5], 'r--', linewidth=1, label='y=x')
        for disease, x_value, y_value in zip(comp_df['Disease'], comp_df[x_column], comp_df[y_column]):
            ax.annotate(disease, (x_value, y_value),
                        fontsize=7, alpha=0.7, ha='center', va='bottom')
        ax.set_xlabel('1yr Prediction Drop (shifting window)', fontsize=11)
        ax.set_ylabel('10yr Prediction Drop (true washout)', fontsize=11)
        ax.set_title(panel_title, fontsize=13, fontweight='bold')
        ax.grid(True, alpha=0.3)
        ax.legend()

    plt.suptitle('Washout Effect: 1-year (shifting) vs 10/30-year (true washout) Predictions',
                 fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.savefig('washout_comparison_across_horizons.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Summary statistics
    rule = "=" * 80
    print("\n" + rule)
    print("WASHOUT EFFECT COMPARISON ACROSS PREDICTION HORIZONS")
    print(rule)

    print("\nMean AUC drop by prediction horizon:")
    summary_rows = (
        ('Fixed Phi - 1yr (shift)', 'Fixed_1yr_drop'),
        ('Fixed Phi - 10yr (washout)', 'Fixed_10yr_drop'),
        ('Fixed Phi - 30yr (washout)', 'Fixed_30yr_drop'),
        ('Joint Phi - 1yr (shift)', 'Joint_1yr_drop'),
        ('Joint Phi - 10yr (washout)', 'Joint_10yr_drop'),
        ('Joint Phi - 30yr (washout)', 'Joint_30yr_drop'),
    )
    for row_label, drop_column in summary_rows:
        drops = comp_df[drop_column]
        print(f"  {row_label}: {drops.mean():.3f} ± {drops.std():.3f}")

    print("\n" + rule)
    print("Summary table saved to 'washout_comparison_table.csv'")
    comp_df.to_csv('washout_comparison_table.csv', index=False)
    print(rule)
else:
    # Nothing to plot: dump a few names from each side plus the 10-year table's
    # index and columns to help locate the mismatch.
    print("ERROR: No common diseases found after fixing index.")
    print(f"1yr diseases (first 5): {sorted(diseases_1yr)[:5]}")
    print(f"10yr diseases (first 5): {sorted(diseases_10yr)[:5]}")
    print(f"\njoint_10yr index sample: {list(joint_10yr.index[:5])}")
    print(f"joint_10yr columns: {joint_10yr.columns.tolist()}")


In [None]:
# 1. Better visualization showing all prediction horizons
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def _load_stats_table(csv_name):
    """Read one summary CSV and use its 'Unnamed: 0' column, if present, as the index."""
    stats_table = pd.read_csv(csv_name)
    if 'Unnamed: 0' in stats_table.columns:
        stats_table = stats_table.set_index('Unnamed: 0')
    return stats_table


# 1-year washout tables; these keep their default index and a 'Disease' column.
washout_fixed = pd.read_csv('/Users/sarahurbut/aladynoulli2/claudefile/output/washout_summary_table.csv')
washout_joint = pd.read_csv('/Users/sarahurbut/aladynoulli2/claudefile/output/washout_summary_table_jointest.csv')

# 10/30-year results without washout
joint_10yr = _load_stats_table('joint_phi_10yr_median_auc_iqr.csv')
joint_30yr = _load_stats_table('joint_phi_30yr_median_auc_iqr.csv')
fixed_10yr = _load_stats_table('fixed_phi_10yr_median_auc_iqr.csv')
fixed_30yr = _load_stats_table('fixed_phi_30yr_median_auc_iqr.csv')

# 10/30-year results with washout
joint_10yr_washout = _load_stats_table('joint_phi_10yr_median_auc_iqr_washout.csv')
joint_30yr_washout = _load_stats_table('joint_phi_30yr_median_auc_iqr_washout.csv')
fixed_10yr_washout = _load_stats_table('fixed_phi_10yr_median_auc_iqr_washout.csv')
fixed_30yr_washout = _load_stats_table('fixed_phi_30yr_median_auc_iqr_washout.csv')

# Candidate diseases: present in the fixed-phi washout table and in the
# joint-phi 10-year index. The other tables are not checked at this point.
diseases_common = set(washout_fixed['Disease']).intersection(joint_10yr.index)
diseases_sorted = sorted(diseases_common)

# Prepare comprehensive comparison: one row per disease with the AUC at every
# prediction horizon (1yr, 10yr, 30yr), with and without washout, for both runs.
comparison_columns = [
    'Disease',
    'Fixed_1yr', 'Fixed_10yr', 'Fixed_10yr_washout', 'Fixed_30yr', 'Fixed_30yr_washout',
    'Joint_1yr', 'Joint_10yr', 'Joint_10yr_washout', 'Joint_30yr', 'Joint_30yr_washout',
]
comparison_all = []
for disease in diseases_sorted:
    try:
        # 1-year predictions (shifting window): the 0yr AUC of the first matching row.
        fixed_1yr = washout_fixed[washout_fixed['Disease'] == disease]['0yr_AUC'].values[0]
        joint_1yr = washout_joint[washout_joint['Disease'] == disease]['0yr_AUC'].values[0]

        # 10/30-year predictions (no washout)
        fixed_10yr_auc = fixed_10yr.loc[disease, 'median_auc']
        fixed_30yr_auc = fixed_30yr.loc[disease, 'median_auc']
        joint_10yr_auc = joint_10yr.loc[disease, 'median_auc']
        joint_30yr_auc = joint_30yr.loc[disease, 'median_auc']

        # 10/30-year predictions (with washout)
        fixed_10yr_w = fixed_10yr_washout.loc[disease, 'median_auc']
        fixed_30yr_w = fixed_30yr_washout.loc[disease, 'median_auc']
        joint_10yr_w = joint_10yr_washout.loc[disease, 'median_auc']
        joint_30yr_w = joint_30yr_washout.loc[disease, 'median_auc']

        comparison_all.append({
            'Disease': disease,
            'Fixed_1yr': fixed_1yr,
            'Fixed_10yr': fixed_10yr_auc,
            'Fixed_10yr_washout': fixed_10yr_w,
            'Fixed_30yr': fixed_30yr_auc,
            'Fixed_30yr_washout': fixed_30yr_w,
            'Joint_1yr': joint_1yr,
            'Joint_10yr': joint_10yr_auc,
            'Joint_10yr_washout': joint_10yr_w,
            'Joint_30yr': joint_30yr_auc,
            'Joint_30yr_washout': joint_30yr_w
        })
    except Exception as e:
        # Still best effort (a disease missing from any table is skipped), but
        # the skip is reported instead of being swallowed, and KeyboardInterrupt
        # / SystemExit are no longer caught.
        print(f"Error processing {disease}: {e}")
        continue

# Passing the columns explicitly keeps the frame well-formed even when every
# disease was skipped, so later column lookups do not fail with KeyError.
comp_all_df = pd.DataFrame(comparison_all, columns=comparison_columns)
if comp_all_df.empty:
    print("WARNING: no disease had data in every table; the comparison is empty.")

# Comprehensive 2x2 figure: the top row shows AUC bars per prediction horizon,
# the bottom row shows 10yr against 30yr AUC; left column is fixed phi, right
# column is joint phi.
fig, axes = plt.subplots(2, 2, figsize=(20, 14))
x_pos = np.arange(len(comp_all_df))
width = 0.15

# Each bar series is (column, legend label, colour); bars are drawn left to
# right at offsets of -2, -1, 0, +1 and +2 bar widths around each x position.
bar_panels = (
    (axes[0, 0], 'Fixed Phi: AUC by Prediction Horizon', (
        ('Fixed_1yr', '1yr (shift)', 'blue'),
        ('Fixed_10yr', '10yr (no washout)', 'green'),
        ('Fixed_10yr_washout', '10yr (washout)', 'lightgreen'),
        ('Fixed_30yr', '30yr (no washout)', 'orange'),
        ('Fixed_30yr_washout', '30yr (washout)', 'lightcoral'),
    )),
    (axes[0, 1], 'Joint Phi: AUC by Prediction Horizon', (
        ('Joint_1yr', '1yr (shift)', 'darkred'),
        ('Joint_10yr', '10yr (no washout)', 'green'),
        ('Joint_10yr_washout', '10yr (washout)', 'lightgreen'),
        ('Joint_30yr', '30yr (no washout)', 'orange'),
        ('Joint_30yr_washout', '30yr (washout)', 'lightcoral'),
    )),
)
for ax, panel_title, bar_series in bar_panels:
    for slot, (auc_column, series_label, bar_color) in zip((-2, -1, 0, 1, 2), bar_series):
        ax.bar(x_pos + slot * width, comp_all_df[auc_column], width,
               label=series_label, alpha=0.8, color=bar_color)
    ax.set_xlabel('Disease', fontsize=11)
    ax.set_ylabel('AUC', fontsize=11)
    ax.set_title(panel_title, fontsize=13, fontweight='bold')
    ax.set_xticks(x_pos)
    ax.set_xticklabels(comp_all_df['Disease'], rotation=45, ha='right', fontsize=8)
    ax.grid(True, alpha=0.3, axis='y')
    ax.legend(fontsize=9)
    ax.set_ylim([0.3, 1.0])

# Scatter panels: no-washout points (blue circles) and washout points (red
# squares) against a y=x reference line. Labels are attached to the
# no-washout points only.
scatter_panels = (
    (axes[1, 0], 'Fixed Phi: 10yr vs 30yr Predictions',
     'Fixed_10yr', 'Fixed_30yr', 'Fixed_10yr_washout', 'Fixed_30yr_washout'),
    (axes[1, 1], 'Joint Phi: 10yr vs 30yr Predictions',
     'Joint_10yr', 'Joint_30yr', 'Joint_10yr_washout', 'Joint_30yr_washout'),
)
for ax, panel_title, x_column, y_column, x_washout_column, y_washout_column in scatter_panels:
    ax.scatter(comp_all_df[x_column], comp_all_df[y_column],
               alpha=0.6, s=100, label='No washout', color='blue')
    ax.scatter(comp_all_df[x_washout_column], comp_all_df[y_washout_column],
               alpha=0.6, s=100, label='With washout', color='red', marker='s')
    ax.plot([0.4, 1.0], [0.4, 1.0], 'k--', linewidth=1, alpha=0.5, label='y=x')
    for disease, x_value, y_value in zip(comp_all_df['Disease'],
                                         comp_all_df[x_column], comp_all_df[y_column]):
        ax.annotate(disease, (x_value, y_value),
                    fontsize=6, alpha=0.6, ha='center', va='bottom')
    ax.set_xlabel('10-Year AUC', fontsize=11)
    ax.set_ylabel('30-Year AUC', fontsize=11)
    ax.set_title(panel_title, fontsize=13, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend()
    ax.set_aspect('equal', adjustable='box')

plt.suptitle('Comprehensive Prediction Performance: 1yr, 10yr, and 30yr Horizons',
             fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig('all_prediction_horizons_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

# Persist the table behind the figure (row index omitted).
comp_all_df.to_csv('all_prediction_horizons_comparison.csv', index=False)
print("✓ Comprehensive comparison saved!")

In [None]:
from fig5utils import *
import pandas as pd
import numpy as np
import torch

import traceback

# Load full pce_df (sliced per batch below)
pce_df_full = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/pce_prevent_full.csv')
# NOTE(review): `essentials` and `model` are not created in this cell; they must
# already exist in the kernel (an earlier cell or the `fig5utils` star import).
# Confirm before relying on Restart & Run All.
disease_names = essentials['disease_names']

# Batch layout shared by the tensor slices and the checkpoint file names.
FIXEDPHI_BATCH_SIZE = 10000
FIXEDPHI_N_BATCHES = 41
FIXEDPHI_N_BOOTSTRAPS = 100

# Storage for results - SEPARATE variables for each analysis type
# Fixed phi from ENROLLMENT data
fixed_enrollment_10yr_results = []
fixed_enrollment_30yr_results = []
fixed_enrollment_static_10yr_results = []

# Fixed phi from RETROSPECTIVE data
fixed_retrospective_10yr_results = []
fixed_retrospective_30yr_results = []
fixed_retrospective_static_10yr_results = []

# Load full tensors once (shared across both analyses); skipped when the kernel
# already holds them from a previous run of this cell.
if 'Y_full' not in globals():
    Y_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/Y_tensor.pt')
if 'E_full' not in globals():
    E_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_enrollment_full.pt')

# Loop over batches 0..FIXEDPHI_N_BATCHES-1. A batch whose checkpoint file is
# missing is reported and skipped, so a result list may end up shorter than the
# number of batches and list position is NOT guaranteed to equal batch_idx.
for batch_idx in range(FIXEDPHI_N_BATCHES):
    start_idx = batch_idx * FIXEDPHI_BATCH_SIZE
    end_idx = (batch_idx + 1) * FIXEDPHI_BATCH_SIZE

    print(f"\n{'='*80}")
    print(f"Processing batch {batch_idx}: {start_idx} to {end_idx}")
    print(f"{'='*80}")

    # Get pce_df subset for this batch
    pce_df_subset = pce_df_full[start_idx:end_idx].copy().reset_index(drop=True)

    # Extract batch from full tensors (shared for both analyses)
    Y_batch = Y_full[start_idx:end_idx]
    E_batch = E_full[start_idx:end_idx]

    # NOTE(review): the two checkpoint roots differ ('Dropbox' vs 'Dropbox-Personal');
    # kept exactly as in the original cell — confirm both locations are intended.
    fixed_enrollment_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_ENROLLMENT_pooled/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'
    fixed_retrospective_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/enrollment_predictions_fixedphi_RETROSPECTIVE_pooled/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'

    # One entry per phi source:
    # (label used in log lines, analysis_type tag, checkpoint path, three result lists)
    fixed_phi_analyses = (
        ('ENROLLMENT', 'fixed_enrollment', fixed_enrollment_ckpt_path,
         fixed_enrollment_10yr_results, fixed_enrollment_30yr_results,
         fixed_enrollment_static_10yr_results),
        ('RETROSPECTIVE', 'fixed_retrospective', fixed_retrospective_ckpt_path,
         fixed_retrospective_10yr_results, fixed_retrospective_30yr_results,
         fixed_retrospective_static_10yr_results),
    )

    for (phi_source, analysis_type, ckpt_path,
         results_10yr, results_30yr, results_static_10yr) in fixed_phi_analyses:
        try:
            print(f"\n--- Fixed Phi ({phi_source}) ---")
            fixed_ckpt = torch.load(ckpt_path, weights_only=False)
            model.load_state_dict(fixed_ckpt['model_state_dict'])

            # Update model.Y and model.N so forward() uses correct patients
            model.Y = torch.tensor(Y_batch, dtype=torch.float32)
            model.N = Y_batch.shape[0]

            # 10-year predictions
            print(f"Fixed Phi ({phi_source}) - 10 year predictions...")
            fixed_10yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
                model, Y_batch, E_batch, disease_names, pce_df_subset,
                n_bootstraps=FIXEDPHI_N_BOOTSTRAPS, follow_up_duration_years=10, patient_indices=None
            )
            fixed_10yr['batch_idx'] = batch_idx
            fixed_10yr['analysis_type'] = analysis_type
            results_10yr.append(fixed_10yr)

            # 30-year predictions
            print(f"Fixed Phi ({phi_source}) - 30 year predictions...")
            fixed_30yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic(
                model, Y_batch, E_batch, disease_names, pce_df_subset,
                n_bootstraps=FIXEDPHI_N_BOOTSTRAPS, follow_up_duration_years=30, patient_indices=None
            )
            fixed_30yr['batch_idx'] = batch_idx
            fixed_30yr['analysis_type'] = analysis_type
            results_30yr.append(fixed_30yr)

            # Static 10-year predictions (using 1-year score)
            print(f"Fixed Phi ({phi_source}) - Static 10 year predictions...")
            fixed_static_10yr = evaluate_major_diseases_wsex_with_bootstrap(
                model=model,
                Y_100k=Y_batch,
                E_100k=E_batch,
                disease_names=disease_names,
                pce_df=pce_df_subset,
                n_bootstraps=FIXEDPHI_N_BOOTSTRAPS,
                follow_up_duration_years=10,
            )
            fixed_static_10yr['batch_idx'] = batch_idx
            fixed_static_10yr['analysis_type'] = analysis_type
            results_static_10yr.append(fixed_static_10yr)

        except FileNotFoundError:
            print(f"Fixed phi ({phi_source}) checkpoint not found: {ckpt_path}")
        except Exception as e:
            # Best effort: report the failure and continue with the next analysis/batch.
            # An error part-way through leaves the earlier horizons for this batch
            # already appended, so the three lists can differ in length.
            print(f"Error processing fixed phi ({phi_source}) checkpoint {batch_idx}: {e}")
            traceback.print_exc()

print(f"\n{'='*80}")
print("Completed processing all checkpoints!")
print(f"{'='*80}")
print(f"Fixed Enrollment - 10yr: {len(fixed_enrollment_10yr_results)} batches")
print(f"Fixed Enrollment - 30yr: {len(fixed_enrollment_30yr_results)} batches")
print(f"Fixed Enrollment - Static 10yr: {len(fixed_enrollment_static_10yr_results)} batches")
print(f"Fixed Retrospective - 10yr: {len(fixed_retrospective_10yr_results)} batches")
print(f"Fixed Retrospective - 30yr: {len(fixed_retrospective_30yr_results)} batches")
print(f"Fixed Retrospective - Static 10yr: {len(fixed_retrospective_static_10yr_results)} batches")


In [None]:
# ===== SAVE RESULTS TO DISK (to avoid rerunning long computation) =====
# Run this cell after the evaluation loop has populated the six result lists.
import os

print(f"\n{'='*80}")
print("SAVING RESULTS TO DISK")
print(f"{'='*80}")

results_dir = '/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/saved_results/'
os.makedirs(results_dir, exist_ok=True)

# Each list is written to '<results_dir><variable name>.pt'; grouped so the
# progress messages match the two analysis families.
saved_result_groups = (
    ("\nSaving Fixed Enrollment results...", {
        'fixed_enrollment_10yr_results': fixed_enrollment_10yr_results,
        'fixed_enrollment_30yr_results': fixed_enrollment_30yr_results,
        'fixed_enrollment_static_10yr_results': fixed_enrollment_static_10yr_results,
    }),
    ("Saving Fixed Retrospective results...", {
        'fixed_retrospective_10yr_results': fixed_retrospective_10yr_results,
        'fixed_retrospective_30yr_results': fixed_retrospective_30yr_results,
        'fixed_retrospective_static_10yr_results': fixed_retrospective_static_10yr_results,
    }),
)

# Save all 6 result lists
for progress_message, named_results in saved_result_groups:
    print(progress_message)
    for result_name, result_list in named_results.items():
        torch.save(result_list, f'{results_dir}{result_name}.pt')

print(f"\n✓ All results saved to {results_dir}")
for progress_message, named_results in saved_result_groups:
    for result_name, result_list in named_results.items():
        print(f"  - {result_name}.pt ({len(result_list)} batches)")

print(f"\n{'='*80}")
print("To reload later, use:")
print(f"{'='*80}")
# These must be f-strings so the real directory is printed. weights_only=False is
# included in the hint because the saved objects are lists of Python dicts rather
# than plain tensors, matching the other checkpoint loads in this notebook.
for progress_message, named_results in saved_result_groups:
    for result_name in named_results:
        print(f"{result_name} = torch.load('{results_dir}{result_name}.pt', weights_only=False)")


In [None]:
# ===== AGGREGATE AND SAVE RESULTS =====
print(f"\n{'='*80}")
print("AGGREGATING RESULTS ACROSS ALL BATCHES")
print(f"{'='*80}")

def aggregate_results_to_dataframe(results_list, analysis_name):
    """
    Aggregate per-batch evaluation results into one row per disease.

    Parameters
    ----------
    results_list : list of dict
        One dict per batch. Keys are disease names mapping to a metrics dict
        (read keys: 'auc', 'ci_lower', 'ci_upper', 'n_events', 'event_rate'),
        plus the metadata keys 'batch_idx' and 'analysis_type', which are ignored.
    analysis_name : str
        Label used only in the "no results" warning.

    Returns
    -------
    pandas.DataFrame
        Indexed by 'Disease', sorted by 'AUC_median' descending. A disease is
        included only if at least one batch has a non-missing AUC. When nothing
        can be aggregated the frame is empty but still carries the summary
        columns, so downstream column lookups do not raise KeyError.
    """
    columns = ['Disease', 'AUC_median', 'AUC_mean', 'AUC_std', 'AUC_min', 'AUC_max',
               'CI_lower_median', 'CI_upper_median', 'CI_lower_min', 'CI_upper_max',
               'Total_Events', 'Mean_Event_Rate', 'N_Batches']

    def _is_number(value):
        # True for a usable numeric value; None, NaN and non-numeric objects are
        # rejected (np.isnan raises TypeError on None / strings).
        try:
            return value is not None and not np.isnan(value)
        except TypeError:
            return False

    if not results_list:
        print(f"Warning: No results found for {analysis_name}")
        return pd.DataFrame(columns=columns).set_index('Disease')

    # Union of disease names over ALL batches (first-seen order), so a disease
    # that is absent from the first batch is not silently dropped.
    metadata_keys = {'batch_idx', 'analysis_type'}
    disease_names_list = list(dict.fromkeys(
        key for result in results_list for key in result if key not in metadata_keys
    ))

    # Collect all metrics across batches
    aggregated_data = []
    for disease in disease_names_list:
        per_batch = [result[disease] for result in results_list
                     if isinstance(result.get(disease), dict)]

        aucs = [m['auc'] for m in per_batch if _is_number(m.get('auc'))]
        ci_lowers = [m['ci_lower'] for m in per_batch if _is_number(m.get('ci_lower'))]
        ci_uppers = [m['ci_upper'] for m in per_batch if _is_number(m.get('ci_upper'))]
        n_events_list = [m['n_events'] for m in per_batch if 'n_events' in m]
        event_rates = [m['event_rate'] for m in per_batch if m.get('event_rate') is not None]

        if aucs:  # Only add if we have at least one valid AUC
            aggregated_data.append({
                'Disease': disease,
                'AUC_median': np.median(aucs),
                'AUC_mean': np.mean(aucs),
                'AUC_std': np.std(aucs),
                'AUC_min': np.min(aucs),
                'AUC_max': np.max(aucs),
                'CI_lower_median': np.median(ci_lowers) if ci_lowers else np.nan,
                'CI_upper_median': np.median(ci_uppers) if ci_uppers else np.nan,
                'CI_lower_min': np.min(ci_lowers) if ci_lowers else np.nan,
                'CI_upper_max': np.max(ci_uppers) if ci_uppers else np.nan,
                'Total_Events': np.sum(n_events_list) if n_events_list else np.nan,
                'Mean_Event_Rate': np.mean(event_rates) if event_rates else np.nan,
                'N_Batches': len(aucs)
            })

    df = pd.DataFrame(aggregated_data, columns=columns)
    return df.set_index('Disease').sort_values('AUC_median', ascending=False)

# Aggregate all 6 result lists
print("\nAggregating Fixed Enrollment results...")
fixed_enrollment_10yr_df = aggregate_results_to_dataframe(fixed_enrollment_10yr_results, "Fixed Enrollment 10yr")
fixed_enrollment_30yr_df = aggregate_results_to_dataframe(fixed_enrollment_30yr_results, "Fixed Enrollment 30yr")
fixed_enrollment_static_10yr_df = aggregate_results_to_dataframe(fixed_enrollment_static_10yr_results, "Fixed Enrollment Static 10yr")

print("Aggregating Fixed Retrospective results...")
fixed_retrospective_10yr_df = aggregate_results_to_dataframe(fixed_retrospective_10yr_results, "Fixed Retrospective 10yr")
fixed_retrospective_30yr_df = aggregate_results_to_dataframe(fixed_retrospective_30yr_results, "Fixed Retrospective 30yr")
fixed_retrospective_static_10yr_df = aggregate_results_to_dataframe(fixed_retrospective_static_10yr_results, "Fixed Retrospective Static 10yr")

# One entry per aggregated table:
# (column in the combined table, CSV file name, label for the summary, DataFrame)
pooled_summaries = [
    ('Fixed_Enrollment_10yr', 'pooled_fixed_enrollment_10yr.csv',
     'Fixed Enrollment - 10yr', fixed_enrollment_10yr_df),
    ('Fixed_Enrollment_30yr', 'pooled_fixed_enrollment_30yr.csv',
     'Fixed Enrollment - 30yr', fixed_enrollment_30yr_df),
    ('Fixed_Enrollment_Static_10yr', 'pooled_fixed_enrollment_static_10yr.csv',
     'Fixed Enrollment - Static 10yr', fixed_enrollment_static_10yr_df),
    ('Fixed_Retrospective_10yr', 'pooled_fixed_retrospective_10yr.csv',
     'Fixed Retrospective - 10yr', fixed_retrospective_10yr_df),
    ('Fixed_Retrospective_30yr', 'pooled_fixed_retrospective_30yr.csv',
     'Fixed Retrospective - 30yr', fixed_retrospective_30yr_df),
    ('Fixed_Retrospective_Static_10yr', 'pooled_fixed_retrospective_static_10yr.csv',
     'Fixed Retrospective - Static 10yr', fixed_retrospective_static_10yr_df),
]

# Save individual DataFrames
output_dir = '/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/'
print(f"\nSaving aggregated results to {output_dir}...")

for column_name, csv_name, summary_label, summary_df in pooled_summaries:
    summary_df.to_csv(f'{output_dir}{csv_name}')

print("✓ Saved individual result files")

# Create a combined comparison DataFrame (similar to comparison_all_approaches format)
print("\nCreating combined comparison DataFrame...")
all_diseases = set()
for column_name, csv_name, summary_label, summary_df in pooled_summaries:
    if not summary_df.empty:
        all_diseases.update(summary_df.index)

# NOTE: `comparison_df` is also the name used by the Cox-baseline cell at the top
# of the notebook; running this cell replaces that table.
comparison_df = pd.DataFrame(index=sorted(all_diseases))
for column_name, csv_name, summary_label, summary_df in pooled_summaries:
    # An analysis with no results yields a frame without 'AUC_median'; fill its
    # column with NaN instead of raising KeyError.
    comparison_df[column_name] = summary_df.get('AUC_median', np.nan)

comparison_df.to_csv(f'{output_dir}pooled_comparison_all_approaches.csv')
print("✓ Saved combined comparison file: pooled_comparison_all_approaches.csv")

# Print summary
print(f"\n{'='*80}")
print("SUMMARY OF AGGREGATED RESULTS")
print(f"{'='*80}")
print()
for column_name, csv_name, summary_label, summary_df in pooled_summaries:
    print(f"{summary_label}: {len(summary_df)} diseases")

top10_columns = ['AUC_median', 'CI_lower_median', 'CI_upper_median', 'N_Batches']
for top10_title, top10_df in (
    ("TOP 10 DISEASES BY AUC (Fixed Enrollment 10yr)", fixed_enrollment_10yr_df),
    ("TOP 10 DISEASES BY AUC (Fixed Retrospective 10yr)", fixed_retrospective_10yr_df),
):
    print(f"\n{'='*80}")
    print(top10_title)
    print(f"{'='*80}")
    if not top10_df.empty:
        print(top10_df[top10_columns].head(10).round(4))

print(f"\n{'='*80}")
print("All results saved successfully!")
print(f"{'='*80}")

In [None]:
# Number of batches for which the Fixed Retrospective 10-year evaluation completed
# (the evaluation loop appends one entry per successfully evaluated batch).
len(fixed_retrospective_10yr_results)

In [None]:
# Compare batches 0-9 against batches 10 and later.
# The groups are selected by each result's recorded 'batch_idx', not by list
# position: the evaluation loop skips batches whose checkpoint is missing or
# fails, so position i in a result list is not guaranteed to be batch i.
# NOTE(review): in this part of the notebook compute_aggregated_cis is defined in
# a LATER cell; make sure that definition has been run before this cell.

def _split_by_batch_idx(results_list, cutoff):
    """Split per-batch results into (batch_idx < cutoff, batch_idx >= cutoff)."""
    early, late = [], []
    for position, result in enumerate(results_list):
        # Fall back to list position for results that carry no 'batch_idx'.
        if result.get('batch_idx', position) < cutoff:
            early.append(result)
        else:
            late.append(result)
    return early, late

# First 10 batches (0-9) / remaining batches (10 and later)
first_10_enrollment_10yr, last_30_enrollment_10yr = _split_by_batch_idx(fixed_enrollment_10yr_results, 10)
first_10_retrospective_10yr, last_30_retrospective_10yr = _split_by_batch_idx(fixed_retrospective_10yr_results, 10)

# Use your existing compute_aggregated_cis function
first_10_enroll_agg = compute_aggregated_cis(first_10_enrollment_10yr)
last_30_enroll_agg = compute_aggregated_cis(last_30_enrollment_10yr)

# The retrospective aggregates are computed but not used in the comparison below.
first_10_retro_agg = compute_aggregated_cis(first_10_retrospective_10yr)
last_30_retro_agg = compute_aggregated_cis(last_30_retrospective_10yr)

# Calculate differences (enrollment results only; rows are aligned on disease name)
comparison = pd.DataFrame({
    'First_10_Batches': first_10_enroll_agg['median_auc'],
    'Last_30_Batches': last_30_enroll_agg['median_auc'],
})
comparison['Difference'] = comparison['First_10_Batches'] - comparison['Last_30_Batches']

# Summary statistics
print(f"Mean difference (First 10 - Last 30): {comparison['Difference'].mean():.4f}")
print(f"Median difference: {comparison['Difference'].median():.4f}")
print(f"Std difference: {comparison['Difference'].std():.4f}")
print(f"\nDiseases where First 10 > Last 30: {(comparison['Difference'] > 0).sum()} / {len(comparison)}")
print(f"Diseases where Last 30 > First 10: {(comparison['Difference'] < 0).sum()} / {len(comparison)}")
print(f"Diseases with difference < 0.001: {(abs(comparison['Difference']) < 0.001).sum()}")

# Show top differences
print("\nTop 5 where First 10 batches are better:")
print(comparison.nlargest(5, 'Difference')[['First_10_Batches', 'Last_30_Batches', 'Difference']])

print("\nTop 5 where Last 30 batches are better:")
print(comparison.nsmallest(5, 'Difference')[['First_10_Batches', 'Last_30_Batches', 'Difference']])

In [None]:
# Aggregate new results and compare with existing CSVs.
# The batch count in the banner is taken from the results actually in memory
# rather than hard-coded, so it stays correct as more batches are processed.
print("="*80)
print(f"AGGREGATING NEW RESULTS ({len(fixed_enrollment_10yr_results)} batches) AND COMPARING WITH EXISTING")
print("="*80)

# compute_aggregated_cis is (re)defined here so this cell is self-contained.
# NOTE(review): an earlier cell in this notebook already calls it; a single
# definition placed before its first use would survive Restart & Run All.
def compute_aggregated_cis(results_list, name=""):
    """
    Aggregate per-batch AUCs and bootstrap CI bounds into one row per disease.

    Parameters
    ----------
    results_list : list of dict
        One dict per batch. Keys are disease names mapping to a metrics dict
        (read keys: 'auc', 'ci_lower', 'ci_upper'), plus the metadata keys
        'batch_idx' and 'analysis_type', which are ignored.
    name : str, optional
        Unused; kept so existing calls that pass a label keep working.

    Returns
    -------
    pandas.DataFrame
        Indexed by disease, sorted by 'median_auc' descending, with columns
        'median_auc', 'ci_lower_median', 'ci_upper_median', 'ci_lower_min',
        'ci_upper_max', 'n_batches'. A disease with no valid AUC in any batch
        gets NaN statistics and n_batches == 0. With nothing to aggregate the
        frame is empty but still carries these columns.
    """
    summary_columns = ['median_auc', 'ci_lower_median', 'ci_upper_median',
                       'ci_lower_min', 'ci_upper_max', 'n_batches']
    if not results_list:
        return pd.DataFrame(columns=summary_columns)

    def _is_number(value):
        # None, NaN and non-numeric objects are treated as missing
        # (np.isnan raises TypeError on None / strings).
        try:
            return value is not None and not np.isnan(value)
        except TypeError:
            return False

    # Union of disease names over ALL batches (first-seen order), so a disease
    # missing from the first batch is still aggregated.
    metadata_keys = {'batch_idx', 'analysis_type'}
    disease_names_list = list(dict.fromkeys(
        key for result in results_list for key in result if key not in metadata_keys
    ))

    # Aggregate: median of bounds and median AUC
    aggregated = {}
    for disease in disease_names_list:
        per_batch = [result[disease] for result in results_list
                     if isinstance(result.get(disease), dict)]
        aucs = [m['auc'] for m in per_batch if _is_number(m.get('auc'))]
        ci_lowers = [m['ci_lower'] for m in per_batch if _is_number(m.get('ci_lower'))]
        ci_uppers = [m['ci_upper'] for m in per_batch if _is_number(m.get('ci_upper'))]

        # CI statistics are reported only when the disease has at least one valid AUC.
        has_lower = bool(aucs and ci_lowers)
        has_upper = bool(aucs and ci_uppers)
        aggregated[disease] = {
            'median_auc': np.median(aucs) if aucs else np.nan,
            'ci_lower_median': np.median(ci_lowers) if has_lower else np.nan,
            'ci_upper_median': np.median(ci_uppers) if has_upper else np.nan,
            'ci_lower_min': np.min(ci_lowers) if has_lower else np.nan,
            'ci_upper_max': np.max(ci_uppers) if has_upper else np.nan,
            'n_batches': len(aucs)
        }

    if not aggregated:
        # Results contained only metadata keys; avoid a KeyError in sort_values.
        return pd.DataFrame(columns=summary_columns)

    df = pd.DataFrame(aggregated).T
    df = df.sort_values('median_auc', ascending=False)

    return df

# Aggregate the pooled fixed-phi results (10-year and 30-year, both phi sources).
fixed_enrollment_10yr_aggregated = compute_aggregated_cis(fixed_enrollment_10yr_results, "Fixed Enrollment 10yr")
fixed_enrollment_30yr_aggregated = compute_aggregated_cis(fixed_enrollment_30yr_results, "Fixed Enrollment 30yr")
fixed_retrospective_10yr_aggregated = compute_aggregated_cis(fixed_retrospective_10yr_results, "Fixed Retrospective 10yr")
fixed_retrospective_30yr_aggregated = compute_aggregated_cis(fixed_retrospective_30yr_results, "Fixed Retrospective 30yr")

# Preview the ten highest-AUC diseases of each 10-year table.
preview_columns = ['median_auc', 'ci_lower_median', 'ci_upper_median', 'n_batches']
for preview_heading, aggregated_table in (
    ("\nFixed Enrollment (Pooled) - 10yr:", fixed_enrollment_10yr_aggregated),
    ("\nFixed Retrospective (Pooled) - 10yr:", fixed_retrospective_10yr_aggregated),
):
    print(preview_heading)
    print(aggregated_table[preview_columns].head(10))

# Write all four aggregated tables (paths are relative to the working directory).
for aggregated_csv, aggregated_table in (
    ('fixed_enrollment_pooled_10yr_aggregated_cis.csv', fixed_enrollment_10yr_aggregated),
    ('fixed_enrollment_pooled_30yr_aggregated_cis.csv', fixed_enrollment_30yr_aggregated),
    ('fixed_retrospective_pooled_10yr_aggregated_cis.csv', fixed_retrospective_10yr_aggregated),
    ('fixed_retrospective_pooled_30yr_aggregated_cis.csv', fixed_retrospective_30yr_aggregated),
):
    aggregated_table.to_csv(aggregated_csv)

print("\n✓ New aggregated results saved to CSV files")

# Load existing CSV files for comparison
joint_10yr = pd.read_csv('joint_phi_10yr_aggregated_cis.csv', index_col=0)
joint_30yr = pd.read_csv('joint_phi_30yr_aggregated_cis.csv', index_col=0)
fixed_10yr_old = pd.read_csv('fixed_phi_10yr_aggregated_cis.csv', index_col=0)  # One batch retrospective
fixed_30yr_old = pd.read_csv('fixed_phi_30yr_aggregated_cis.csv', index_col=0)  # One batch retrospective

# Batch counts behind each table: the largest per-disease 'n_batches' value
# (individual diseases can have fewer valid batches than were processed).
n_batches_new = fixed_enrollment_10yr_aggregated['n_batches'].max()
n_batches_joint = joint_10yr['n_batches'].max()
n_batches_fixed_old = fixed_10yr_old['n_batches'].max()

print("\n" + "="*80)
print("⚠️  IMPORTANT: BATCH COMPARISON NOTE")
print("="*80)
print("NEW RESULTS (Fixed Enrollment/Retrospective Pooled):")
# Sample range assumes 10,000 patients per batch, as in the evaluation loop.
print(f"  - Based on {n_batches_new:.0f} batches (0-{n_batches_new * 10:.0f}k samples)")
print("\nOLD RESULTS (Joint/Fixed Retrospective Old):")
print(f"  - Joint: Based on {n_batches_joint:.0f} batches (0-400k samples)")
print(f"  - Fixed Retrospective Old: Based on {n_batches_fixed_old:.0f} batches (0-400k samples)")
# Warn only when the batch counts really differ, and report the actual counts
# instead of the previously hard-coded "10 vs 40".
if n_batches_new != n_batches_joint or n_batches_new != n_batches_fixed_old:
    print(f"\n⚠️  WARNING: Comparing {n_batches_new:.0f} batches (new) vs "
          f"{n_batches_joint:.0f} (joint) / {n_batches_fixed_old:.0f} (fixed old) batches - NOT directly comparable!")
    print("   For fair comparison, re-aggregate all result sets over the same batches.")
else:
    print(f"\n✓ New and old results are all based on {n_batches_new:.0f} batches")
print("="*80)

print("\n" + "="*80)
print("COMPREHENSIVE COMPARISON - 10 YEAR PREDICTIONS")
# Report the batch counts actually present in the tables (largest per-disease value).
print(f"(Note: New results = {fixed_enrollment_10yr_aggregated['n_batches'].max():.0f} batches, "
      f"Old results = {joint_10yr['n_batches'].max():.0f} batches)")
print("="*80)

# Get common diseases (present in all four 10-year tables)
all_diseases_10yr = set(joint_10yr.index) & set(fixed_10yr_old.index) & set(fixed_enrollment_10yr_aggregated.index) & set(fixed_retrospective_10yr_aggregated.index)
diseases_sorted_10yr = sorted(list(all_diseases_10yr))

# Create comparison DataFrame (rows in alphabetical disease order)
comparison_10yr = pd.DataFrame({
    'Joint_Enrollment': joint_10yr.loc[diseases_sorted_10yr, 'median_auc'],
    'Fixed_Retrospective_Old': fixed_10yr_old.loc[diseases_sorted_10yr, 'median_auc'],  # One batch
    'Fixed_Retrospective_Pooled': fixed_retrospective_10yr_aggregated.loc[diseases_sorted_10yr, 'median_auc'],  # Pooled
    'Fixed_Enrollment_Pooled': fixed_enrollment_10yr_aggregated.loc[diseases_sorted_10yr, 'median_auc'],  # Pooled
}, index=diseases_sorted_10yr)

# Calculate differences
comparison_10yr['Fixed_Enroll_vs_Joint'] = comparison_10yr['Fixed_Enrollment_Pooled'] - comparison_10yr['Joint_Enrollment']
comparison_10yr['Fixed_Enroll_vs_Fixed_Retro_Old'] = comparison_10yr['Fixed_Enrollment_Pooled'] - comparison_10yr['Fixed_Retrospective_Old']
comparison_10yr['Fixed_Retro_Pooled_vs_Old'] = comparison_10yr['Fixed_Retrospective_Pooled'] - comparison_10yr['Fixed_Retrospective_Old']

print("\nTop 15 diseases by Joint Enrollment AUC:")
# comparison_10yr itself stays alphabetical; sort a view by Joint AUC so that
# head(15) really is the top 15 the heading promises.
print(comparison_10yr.sort_values('Joint_Enrollment', ascending=False)[
    ['Joint_Enrollment', 'Fixed_Enrollment_Pooled', 'Fixed_Retrospective_Pooled',
     'Fixed_Retrospective_Old', 'Fixed_Enroll_vs_Joint']].head(15).round(3))

print("\n" + "="*80)
print("COMPREHENSIVE COMPARISON - 30 YEAR PREDICTIONS")
# Report the batch counts actually present in the tables (largest per-disease value).
print(f"(Note: New results = {fixed_enrollment_30yr_aggregated['n_batches'].max():.0f} batches, "
      f"Old results = {joint_30yr['n_batches'].max():.0f} batches)")
print("="*80)

# Get common diseases for 30yr (present in all four 30-year tables)
all_diseases_30yr = set(joint_30yr.index) & set(fixed_30yr_old.index) & set(fixed_enrollment_30yr_aggregated.index) & set(fixed_retrospective_30yr_aggregated.index)
diseases_sorted_30yr = sorted(list(all_diseases_30yr))

# Create comparison DataFrame (rows in alphabetical disease order)
comparison_30yr = pd.DataFrame({
    'Joint_Enrollment': joint_30yr.loc[diseases_sorted_30yr, 'median_auc'],
    'Fixed_Retrospective_Old': fixed_30yr_old.loc[diseases_sorted_30yr, 'median_auc'],  # One batch
    'Fixed_Retrospective_Pooled': fixed_retrospective_30yr_aggregated.loc[diseases_sorted_30yr, 'median_auc'],  # Pooled
    'Fixed_Enrollment_Pooled': fixed_enrollment_30yr_aggregated.loc[diseases_sorted_30yr, 'median_auc'],  # Pooled
}, index=diseases_sorted_30yr)

# Calculate differences
comparison_30yr['Fixed_Enroll_vs_Joint'] = comparison_30yr['Fixed_Enrollment_Pooled'] - comparison_30yr['Joint_Enrollment']
comparison_30yr['Fixed_Enroll_vs_Fixed_Retro_Old'] = comparison_30yr['Fixed_Enrollment_Pooled'] - comparison_30yr['Fixed_Retrospective_Old']
comparison_30yr['Fixed_Retro_Pooled_vs_Old'] = comparison_30yr['Fixed_Retrospective_Pooled'] - comparison_30yr['Fixed_Retrospective_Old']

print("\nTop 15 diseases by Joint Enrollment AUC:")
# comparison_30yr itself stays alphabetical; sort a view by Joint AUC so that
# head(15) really is the top 15 the heading promises.
print(comparison_30yr.sort_values('Joint_Enrollment', ascending=False)[
    ['Joint_Enrollment', 'Fixed_Enrollment_Pooled', 'Fixed_Retrospective_Pooled',
     'Fixed_Retrospective_Old', 'Fixed_Enroll_vs_Joint']].head(15).round(3))

# Write both horizon tables (paths are relative to the working directory).
for comparison_table, comparison_csv in (
    (comparison_10yr, 'comparison_all_approaches_10yr.csv'),
    (comparison_30yr, 'comparison_all_approaches_30yr.csv'),
):
    comparison_table.to_csv(comparison_csv)

print("\n✓ Comparison tables saved to CSV files")

# Summary statistics - Focus on key comparisons
def _print_key_comparisons(comparison, horizon_label):
    """
    Print the three key approach comparisons for one prediction horizon.

    Parameters
    ----------
    comparison : pandas.DataFrame
        One row per disease with the AUC columns 'Joint_Enrollment',
        'Fixed_Enrollment_Pooled', 'Fixed_Retrospective_Old',
        'Fixed_Retrospective_Pooled' and the difference columns
        'Fixed_Enroll_vs_Joint', 'Fixed_Retro_Pooled_vs_Old',
        'Fixed_Enroll_vs_Fixed_Retro_Old'.
    horizon_label : str
        Text inserted in the section banner, e.g. "10 YEAR".

    Returns
    -------
    tuple
        (mean difference, count of diseases with a positive difference) for each
        of the three comparisons in printed order, flattened to six values.
    """
    print("\n" + "="*80)
    print(f"KEY COMPARISONS - {horizon_label} PREDICTIONS")
    print("="*80)

    # Per section: heading, explanation, (label, AUC column) pairs whose mean is
    # printed, the difference column, the label of the "better" count, takeaway.
    sections = (
        ("COMPARISON 1: Fixed Enrollment (Pooled) vs Joint Enrollment",
         "(Same enrollment data, different phi estimation: fixed vs joint)",
         (("  Joint Enrollment:              ", 'Joint_Enrollment'),
          ("  Fixed Enrollment (Pooled):     ", 'Fixed_Enrollment_Pooled')),
         'Fixed_Enroll_vs_Joint',
         "  Diseases where Fixed > Joint: ",
         "\n  → Clinical feasibility: Can fixed enrollment phi match joint performance?"),
        ("COMPARISON 2: Fixed Retrospective (Pooled) vs Fixed Retrospective (Old)",
         "(Same approach, different phi source: pooled vs single batch)",
         (("  Fixed Retrospective (Old):     ", 'Fixed_Retrospective_Old'),
          ("  Fixed Retrospective (Pooled):  ", 'Fixed_Retrospective_Pooled')),
         'Fixed_Retro_Pooled_vs_Old',
         "  Diseases where Pooled > Old:    ",
         "\n  → Pooling effect: Does pooling phi improve over single batch?"),
        ("BONUS COMPARISON: Fixed Enrollment vs Fixed Retrospective (Old)",
         "(Different phi source: enrollment vs retrospective)",
         (),
         'Fixed_Enroll_vs_Fixed_Retro_Old',
         "  Diseases where Enrollment > Retro: ",
         "\n  → Enrollment-specific phi: Does enrollment phi outperform retrospective phi?"),
    )

    n_diseases = len(comparison)
    summary = []
    for heading, explanation, mean_columns, diff_column, better_label, takeaway in sections:
        print("\n" + "-"*80)
        print(heading)
        print(explanation)
        print("-"*80)
        for mean_label, auc_column in mean_columns:
            print(f"{mean_label}{comparison[auc_column].mean():.3f}")
        mean_diff = comparison[diff_column].mean()
        print(f"  Mean difference:               {mean_diff:+.4f} ± {comparison[diff_column].std():.4f}")
        better_count = (comparison[diff_column] > 0).sum()
        print(f"{better_label}{better_count} / {n_diseases} ({better_count/n_diseases*100:.1f}%)")
        print(takeaway)
        summary.extend((mean_diff, better_count))
    return tuple(summary)


# 10-year horizon (the summary variables keep their original names)
(diff_enroll_vs_joint, better_count_enroll,
 diff_retro_pooled_vs_old, better_count_retro,
 diff_enroll_vs_retro_old, better_count_enroll_vs_retro) = _print_key_comparisons(comparison_10yr, "10 YEAR")

# 30-year horizon
(diff_enroll_vs_joint_30yr, better_count_enroll_30yr,
 diff_retro_pooled_vs_old_30yr, better_count_retro_30yr,
 diff_enroll_vs_retro_old_30yr, better_count_enroll_vs_retro_30yr) = _print_key_comparisons(comparison_30yr, "30 YEAR")


In [None]:
import torch
import numpy as np

# Read the two pooled "master" checkpoint files from disk. Each is loaded with
# weights_only=False, i.e. the complete saved object is unpickled.
enrollment_master = torch.load(
    "/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/master_for_fitting_pooled_enrollment_data.pt",
    weights_only=False,
)
retrospective_master = torch.load(
    "/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/master_for_fitting_pooled_all_data.pt",
    weights_only=False,
)


In [None]:
# Display the 'psi' entry of the enrollment master checkpoint's model_state_dict
# (requires `enrollment_master` to have been loaded by an earlier cell).
enrollment_master['model_state_dict']['psi']

In [None]:
# Load the object saved in initial_psi_400k.pt and display it as the cell output.
# NOTE(review): unlike the other torch.load calls in this notebook, no
# weights_only argument is passed here, so the installed PyTorch default applies.
initial_psi=torch.load("/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/initial_psi_400k.pt")
initial_psi

In [None]:
import torch
import numpy as np

# Load both pooled master checkpoints (complete saved objects, weights_only=False).
enrollment_master = torch.load(
    "/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/master_for_fitting_pooled_enrollment_data.pt",
    weights_only=False,
)
retrospective_master = torch.load(
    "/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/master_for_fitting_pooled_all_data.pt",
    weights_only=False,
)

# Reference phi / psi entries taken from each master's model_state_dict.
_enrollment_state = enrollment_master['model_state_dict']
enrollment_master_phi = _enrollment_state['phi']
enrollment_master_psi = _enrollment_state['psi']
_retrospective_state = retrospective_master['model_state_dict']
retrospective_master_phi = _retrospective_state['phi']
retrospective_master_psi = _retrospective_state['psi']

_banner = "=" * 80
print(_banner)
print("VERIFYING FIXED ENROLLMENT POOLED BATCHES")
print(_banner)

# Compare every enrollment fixed-phi batch checkpoint against the enrollment
# master: phi and psi must agree within atol=1e-6 (torch.allclose).
enrollment_batch_starts = list(range(0, 100000, 10000))  # Adjust as needed
for start_idx in enrollment_batch_starts:
    end_idx = start_idx + 10000
    batch_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_ENROLLMENT_pooled/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'
    try:
        batch_ckpt = torch.load(batch_path, weights_only=False)
        _batch_state = batch_ckpt['model_state_dict']
        batch_phi = _batch_state['phi']
        batch_psi = _batch_state['psi']

        # Both comparisons are evaluated before anything is printed for the batch.
        phi_match = torch.allclose(batch_phi, enrollment_master_phi, atol=1e-6)
        psi_match = torch.allclose(batch_psi, enrollment_master_psi, atol=1e-6)

        print(
            f"\nBatch {start_idx}-{end_idx}:",
            f"  Phi matches master: {phi_match}",
            f"  Psi matches master: {psi_match}",
            sep="\n",
        )

        # On a mismatch, report the largest element-wise absolute deviation.
        if not phi_match:
            diff = torch.max(torch.abs(batch_phi - enrollment_master_phi))
            print(f"  Max phi difference: {diff:.10f}")
        if not psi_match:
            diff = torch.max(torch.abs(batch_psi - enrollment_master_psi))
            print(f"  Max psi difference: {diff:.10f}")
    except Exception as e:
        # Best-effort check: any failure for a batch is reported and the loop moves on.
        print(f"\nBatch {start_idx}-{end_idx}: ERROR - {e}")

# Header for the retrospective check (the loop itself lives in the next cell).
print("\n" + _banner)
print("VERIFYING FIXED RETROSPECTIVE POOLED BATCHES")
print(_banner)


In [None]:

# Compare every retrospective fixed-phi batch checkpoint against the
# retrospective master (retrospective_master_phi / retrospective_master_psi are
# set by the preceding cell): phi and psi must agree within atol=1e-6.
retrospective_batch_starts = list(range(0, 100000, 10000))  # Adjust as needed
for start_idx in retrospective_batch_starts:
    end_idx = start_idx + 10000
    batch_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/enrollment_predictions_fixedphi_RETROSPECTIVE_pooled/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'
    try:
        batch_ckpt = torch.load(batch_path, weights_only=False)
        _batch_state = batch_ckpt['model_state_dict']
        batch_phi = _batch_state['phi']
        batch_psi = _batch_state['psi']

        # Both comparisons are evaluated before anything is printed for the batch.
        phi_match = torch.allclose(batch_phi, retrospective_master_phi, atol=1e-6)
        psi_match = torch.allclose(batch_psi, retrospective_master_psi, atol=1e-6)

        print(
            f"\nBatch {start_idx}-{end_idx}:",
            f"  Phi matches master: {phi_match}",
            f"  Psi matches master: {psi_match}",
            sep="\n",
        )

        # On a mismatch, report the largest element-wise absolute deviation.
        if not phi_match:
            diff = torch.max(torch.abs(batch_phi - retrospective_master_phi))
            print(f"  Max phi difference: {diff:.10f}")
        if not psi_match:
            diff = torch.max(torch.abs(batch_psi - retrospective_master_psi))
            print(f"  Max psi difference: {diff:.10f}")
    except Exception as e:
        # Best-effort check: any failure for a batch is reported and the loop moves on.
        print(f"\nBatch {start_idx}-{end_idx}: ERROR - {e}")

_done_banner = "=" * 80
print("\n" + _done_banner)
print("VERIFICATION COMPLETE")
print(_done_banner)

# Next: the washout evaluation.

In [None]:
import traceback  # hoisted: used by every error handler in the batch loop below

# ---------------------------------------------------------------------------
# Washout evaluation across batches.
#
# For each 10,000-row batch this cell loads three checkpoints into the existing
# `model` (joint phi, fixed phi from the RETROSPECTIVE pooled folder, fixed phi
# from the ENROLLMENT pooled folder) and runs, with washout_years=1 and
# n_bootstraps=100:
#   * the dynamic evaluation with a 10-year follow-up,
#   * the dynamic evaluation with a 30-year follow-up,
#   * the static evaluation with a 10-year follow-up.
# Each result is tagged with its batch_idx and appended to the matching
# *_results_washout list.
#
# Relies on names defined by earlier cells: `pd`, `torch`, `model`, `essentials`
# and the two evaluate_major_diseases_wsex_with_bootstrap_* functions.
# ---------------------------------------------------------------------------

# Load full pce_df
pce_df_full = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/pce_prevent_full.csv')
disease_names = essentials['disease_names']

# Storage for results - Dynamic predictions
joint_10yr_results_washout = []
joint_30yr_results_washout = []
fixed_retrospective_10yr_results_washout = []
fixed_retrospective_30yr_results_washout = []
fixed_enrollment_10yr_results_washout = []
fixed_enrollment_30yr_results_washout = []

# Storage for results - Static predictions (1-year score for 10-year outcome)
joint_static_10yr_results_washout = []
fixed_retrospective_static_10yr_results_washout = []
fixed_enrollment_static_10yr_results_washout = []

# Load full tensors once (shared across all analyses); skipped when a previous
# cell already left them in the kernel namespace.
if 'Y_full' not in globals():
    Y_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/Y_tensor.pt')
if 'E_full' not in globals():
    E_full = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_enrollment_full.pt')

# Loop through batches 0-39 (batch_0_10000 to batch_390000_400000)
for batch_idx in range(40):  # 40 batches (0-39)
    start_idx = batch_idx * 10000
    end_idx = (batch_idx + 1) * 10000

    print(f"\n{'='*80}")
    print(f"Processing batch {batch_idx}: {start_idx} to {end_idx}")
    print(f"{'='*80}")

    # Get pce_df subset for this batch (explicit positional slice, re-indexed from 0)
    pce_df_subset = pce_df_full.iloc[start_idx:end_idx].copy().reset_index(drop=True)

    # Extract batch from full tensors (shared for all analyses)
    Y_batch = Y_full[start_idx:end_idx]
    E_batch = E_full[start_idx:end_idx]

    # === JOINT PHI CHECKPOINTS ===
    joint_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_prediction_jointphi_sex_pcs/enrollment_model_W0.0001_batch_{start_idx}_{end_idx}.pt'

    try:
        joint_ckpt = torch.load(joint_ckpt_path, weights_only=False)
        model.load_state_dict(joint_ckpt['model_state_dict'])

        # Use Y from checkpoint and update model.Y so forward() uses correct patients
        Y_batch_joint = joint_ckpt['Y']
        model.Y = torch.tensor(Y_batch_joint, dtype=torch.float32)
        model.N = Y_batch_joint.shape[0]

        # FIX: evaluate against this batch's slice of E. The previous code passed
        # the global `E_100k`, which is never reassigned inside this loop, so every
        # batch's joint evaluation used the same E rows regardless of start_idx.
        # The fixed-phi sections below already use E_batch.
        if Y_batch_joint.shape[0] != E_batch.shape[0]:
            raise ValueError(
                f"Joint checkpoint Y has {Y_batch_joint.shape[0]} rows but "
                f"E_batch has {E_batch.shape[0]} rows for batch {batch_idx}"
            )

        # 10-year predictions
        print(f"\nJoint Phi - 10 year predictions...")
        joint_10yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic_withwashout(
            model, Y_batch_joint, E_batch, disease_names, pce_df_subset, washout_years=1,
            n_bootstraps=100, follow_up_duration_years=10, patient_indices=None
        )
        joint_10yr['batch_idx'] = batch_idx
        joint_10yr_results_washout.append(joint_10yr)

        # 30-year predictions
        print(f"\nJoint Phi - 30 year predictions...")
        joint_30yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic_withwashout(
            model, Y_batch_joint, E_batch, disease_names, pce_df_subset, washout_years=1,
            n_bootstraps=100, follow_up_duration_years=30, patient_indices=None
        )
        joint_30yr['batch_idx'] = batch_idx
        joint_30yr_results_washout.append(joint_30yr)

        # Static 10-year predictions
        print(f"\nJoint Phi - Static 10 year predictions...")
        joint_static_10yr = evaluate_major_diseases_wsex_with_bootstrap_withwashout(
            model=model,
            Y_100k=Y_batch_joint,
            E_100k=E_batch,
            disease_names=disease_names,
            pce_df=pce_df_subset,
            washout_years=1,
            n_bootstraps=100,
            follow_up_duration_years=10,
        )
        joint_static_10yr['batch_idx'] = batch_idx
        joint_static_10yr_results_washout.append(joint_static_10yr)

    except FileNotFoundError:
        print(f"Joint phi checkpoint not found: {joint_ckpt_path}")
    except Exception as e:
        # Best-effort: report the failure with a traceback and continue with the
        # remaining approaches / batches.
        print(f"Error processing joint phi checkpoint {batch_idx}: {e}")
        traceback.print_exc()

    # === FIXED PHI FROM RETROSPECTIVE POOLED ===
    fixed_retrospective_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/enrollment_predictions_fixedphi_RETROSPECTIVE_pooled/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'

    try:
        print(f"\n--- Fixed Phi (RETROSPECTIVE Pooled) ---")
        fixed_ckpt = torch.load(fixed_retrospective_ckpt_path, weights_only=False)
        model.load_state_dict(fixed_ckpt['model_state_dict'])

        # Update model.Y and model.N so forward() uses correct patients
        model.Y = torch.tensor(Y_batch, dtype=torch.float32)
        model.N = Y_batch.shape[0]

        # 10-year predictions
        print(f"Fixed Phi (RETROSPECTIVE) - 10 year predictions...")
        fixed_10yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic_withwashout(
            model, Y_batch, E_batch, disease_names, pce_df_subset, washout_years=1,
            n_bootstraps=100, follow_up_duration_years=10, patient_indices=None
        )
        fixed_10yr['batch_idx'] = batch_idx
        fixed_10yr['analysis_type'] = 'fixed_retrospective'
        fixed_retrospective_10yr_results_washout.append(fixed_10yr)

        # 30-year predictions
        print(f"Fixed Phi (RETROSPECTIVE) - 30 year predictions...")
        fixed_30yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic_withwashout(
            model, Y_batch, E_batch, disease_names, pce_df_subset, washout_years=1,
            n_bootstraps=100, follow_up_duration_years=30, patient_indices=None
        )
        fixed_30yr['batch_idx'] = batch_idx
        fixed_30yr['analysis_type'] = 'fixed_retrospective'
        fixed_retrospective_30yr_results_washout.append(fixed_30yr)

        # Static 10-year predictions
        print(f"Fixed Phi (RETROSPECTIVE) - Static 10 year predictions...")
        fixed_static_10yr = evaluate_major_diseases_wsex_with_bootstrap_withwashout(
            model=model,
            Y_100k=Y_batch,
            E_100k=E_batch,
            disease_names=disease_names,
            pce_df=pce_df_subset,
            washout_years=1,
            n_bootstraps=100,
            follow_up_duration_years=10,
        )
        fixed_static_10yr['batch_idx'] = batch_idx
        fixed_static_10yr['analysis_type'] = 'fixed_retrospective'
        fixed_retrospective_static_10yr_results_washout.append(fixed_static_10yr)

    except FileNotFoundError:
        print(f"Fixed phi (RETROSPECTIVE) checkpoint not found: {fixed_retrospective_ckpt_path}")
    except Exception as e:
        print(f"Error processing fixed phi (RETROSPECTIVE) checkpoint {batch_idx}: {e}")
        traceback.print_exc()

    # === FIXED PHI FROM ENROLLMENT POOLED ===
    fixed_enrollment_ckpt_path = f'/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_ENROLLMENT_pooled/model_enroll_fixedphi_sex_{start_idx}_{end_idx}.pt'

    try:
        print(f"\n--- Fixed Phi (ENROLLMENT Pooled) ---")
        fixed_ckpt = torch.load(fixed_enrollment_ckpt_path, weights_only=False)
        model.load_state_dict(fixed_ckpt['model_state_dict'])

        # Update model.Y and model.N so forward() uses correct patients
        model.Y = torch.tensor(Y_batch, dtype=torch.float32)
        model.N = Y_batch.shape[0]

        # 10-year predictions
        print(f"Fixed Phi (ENROLLMENT) - 10 year predictions...")
        fixed_10yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic_withwashout(
            model, Y_batch, E_batch, disease_names, pce_df_subset, washout_years=1,
            n_bootstraps=100, follow_up_duration_years=10, patient_indices=None
        )
        fixed_10yr['batch_idx'] = batch_idx
        fixed_10yr['analysis_type'] = 'fixed_enrollment'
        fixed_enrollment_10yr_results_washout.append(fixed_10yr)

        # 30-year predictions
        print(f"Fixed Phi (ENROLLMENT) - 30 year predictions...")
        fixed_30yr = evaluate_major_diseases_wsex_with_bootstrap_dynamic_withwashout(
            model, Y_batch, E_batch, disease_names, pce_df_subset, washout_years=1,
            n_bootstraps=100, follow_up_duration_years=30, patient_indices=None
        )
        fixed_30yr['batch_idx'] = batch_idx
        fixed_30yr['analysis_type'] = 'fixed_enrollment'
        fixed_enrollment_30yr_results_washout.append(fixed_30yr)

        # Static 10-year predictions
        print(f"Fixed Phi (ENROLLMENT) - Static 10 year predictions...")
        fixed_static_10yr = evaluate_major_diseases_wsex_with_bootstrap_withwashout(
            model=model,
            Y_100k=Y_batch,
            E_100k=E_batch,
            disease_names=disease_names,
            pce_df=pce_df_subset,
            washout_years=1,
            n_bootstraps=100,
            follow_up_duration_years=10,
        )
        fixed_static_10yr['batch_idx'] = batch_idx
        fixed_static_10yr['analysis_type'] = 'fixed_enrollment'
        fixed_enrollment_static_10yr_results_washout.append(fixed_static_10yr)

    except FileNotFoundError:
        print(f"Fixed phi (ENROLLMENT) checkpoint not found: {fixed_enrollment_ckpt_path}")
    except Exception as e:
        print(f"Error processing fixed phi (ENROLLMENT) checkpoint {batch_idx}: {e}")
        traceback.print_exc()

# Final tally: how many batches produced a result for each approach / horizon.
# Because each result is appended as soon as it is computed, a failure part-way
# through a batch can leave these lists with different lengths.
print(f"\n{'='*80}")
print("Completed processing all checkpoints!")
print(f"{'='*80}")
print(f"Joint - 10yr: {len(joint_10yr_results_washout)} batches")
print(f"Joint - 30yr: {len(joint_30yr_results_washout)} batches")
print(f"Joint - Static 10yr: {len(joint_static_10yr_results_washout)} batches")
print(f"Fixed Retrospective - 10yr: {len(fixed_retrospective_10yr_results_washout)} batches")
print(f"Fixed Retrospective - 30yr: {len(fixed_retrospective_30yr_results_washout)} batches")
print(f"Fixed Retrospective - Static 10yr: {len(fixed_retrospective_static_10yr_results_washout)} batches")
print(f"Fixed Enrollment - 10yr: {len(fixed_enrollment_10yr_results_washout)} batches")
print(f"Fixed Enrollment - 30yr: {len(fixed_enrollment_30yr_results_washout)} batches")
print(f"Fixed Enrollment - Static 10yr: {len(fixed_enrollment_static_10yr_results_washout)} batches")

In [None]:
# ---------------------------------------------------------------------------
# Aggregate the per-batch washout results, build 10-year and 30-year comparison
# tables (one row per disease present in all three approaches), save them as
# CSV in the current working directory, and print summary statistics.
#
# Relies on `compute_aggregated_cis` and the *_results_washout lists from
# earlier cells. The aggregated frames are indexed by disease and expose a
# 'median_auc' column (that is what is read below).
# ---------------------------------------------------------------------------
print("="*80)
print("AGGREGATING WASHOUT RESULTS (40 batches)")
print("="*80)

# Aggregate all washout results
joint_10yr_washout_agg = compute_aggregated_cis(joint_10yr_results_washout, "Joint 10yr Washout")
joint_30yr_washout_agg = compute_aggregated_cis(joint_30yr_results_washout, "Joint 30yr Washout")
fixed_retro_10yr_washout_agg = compute_aggregated_cis(fixed_retrospective_10yr_results_washout, "Fixed Retrospective 10yr Washout")
fixed_retro_30yr_washout_agg = compute_aggregated_cis(fixed_retrospective_30yr_results_washout, "Fixed Retrospective 30yr Washout")
fixed_enroll_10yr_washout_agg = compute_aggregated_cis(fixed_enrollment_10yr_results_washout, "Fixed Enrollment 10yr Washout")
fixed_enroll_30yr_washout_agg = compute_aggregated_cis(fixed_enrollment_30yr_results_washout, "Fixed Enrollment 30yr Washout")

# Static results (aggregated here; not used in the comparison tables below)
joint_static_10yr_washout_agg = compute_aggregated_cis(joint_static_10yr_results_washout, "Joint Static 10yr Washout")
fixed_retro_static_10yr_washout_agg = compute_aggregated_cis(fixed_retrospective_static_10yr_results_washout, "Fixed Retrospective Static 10yr Washout")
fixed_enroll_static_10yr_washout_agg = compute_aggregated_cis(fixed_enrollment_static_10yr_results_washout, "Fixed Enrollment Static 10yr Washout")

# Actual number of batches that produced a result for each approach / horizon.
print(f"\nBatches processed:")
print(f"  Joint 10yr: {len(joint_10yr_results_washout)} batches")
print(f"  Joint 30yr: {len(joint_30yr_results_washout)} batches")
print(f"  Fixed Retrospective 10yr: {len(fixed_retrospective_10yr_results_washout)} batches")
print(f"  Fixed Retrospective 30yr: {len(fixed_retrospective_30yr_results_washout)} batches")
print(f"  Fixed Enrollment 10yr: {len(fixed_enrollment_10yr_results_washout)} batches")
print(f"  Fixed Enrollment 30yr: {len(fixed_enrollment_30yr_results_washout)} batches")

# Create comparison DataFrames (similar to comparison_all_approaches CSV)
print("\n" + "="*80)
print("CREATING WASHOUT COMPARISON TABLES")
print("="*80)

# Get common diseases: only diseases present in all three aggregated frames
all_diseases_10yr_washout = set(joint_10yr_washout_agg.index) & set(fixed_retro_10yr_washout_agg.index) & set(fixed_enroll_10yr_washout_agg.index)
diseases_sorted_10yr_washout = sorted(list(all_diseases_10yr_washout))

all_diseases_30yr_washout = set(joint_30yr_washout_agg.index) & set(fixed_retro_30yr_washout_agg.index) & set(fixed_enroll_30yr_washout_agg.index)
diseases_sorted_30yr_washout = sorted(list(all_diseases_30yr_washout))

# 10-year washout comparison
comparison_10yr_washout = pd.DataFrame({
    'Joint_Enrollment': joint_10yr_washout_agg.loc[diseases_sorted_10yr_washout, 'median_auc'],
    'Fixed_Retrospective_Pooled': fixed_retro_10yr_washout_agg.loc[diseases_sorted_10yr_washout, 'median_auc'],
    'Fixed_Enrollment_Pooled': fixed_enroll_10yr_washout_agg.loc[diseases_sorted_10yr_washout, 'median_auc'],
}, index=diseases_sorted_10yr_washout)

# Pairwise AUC differences (first-named approach minus second-named approach)
comparison_10yr_washout['Fixed_Enroll_vs_Joint'] = comparison_10yr_washout['Fixed_Enrollment_Pooled'] - comparison_10yr_washout['Joint_Enrollment']
comparison_10yr_washout['Fixed_Retro_vs_Joint'] = comparison_10yr_washout['Fixed_Retrospective_Pooled'] - comparison_10yr_washout['Joint_Enrollment']
comparison_10yr_washout['Fixed_Retro_vs_Enroll'] = comparison_10yr_washout['Fixed_Retrospective_Pooled'] - comparison_10yr_washout['Fixed_Enrollment_Pooled']

# 30-year washout comparison
comparison_30yr_washout = pd.DataFrame({
    'Joint_Enrollment': joint_30yr_washout_agg.loc[diseases_sorted_30yr_washout, 'median_auc'],
    'Fixed_Retrospective_Pooled': fixed_retro_30yr_washout_agg.loc[diseases_sorted_30yr_washout, 'median_auc'],
    'Fixed_Enrollment_Pooled': fixed_enroll_30yr_washout_agg.loc[diseases_sorted_30yr_washout, 'median_auc'],
}, index=diseases_sorted_30yr_washout)

comparison_30yr_washout['Fixed_Enroll_vs_Joint'] = comparison_30yr_washout['Fixed_Enrollment_Pooled'] - comparison_30yr_washout['Joint_Enrollment']
comparison_30yr_washout['Fixed_Retro_vs_Joint'] = comparison_30yr_washout['Fixed_Retrospective_Pooled'] - comparison_30yr_washout['Joint_Enrollment']
comparison_30yr_washout['Fixed_Retro_vs_Enroll'] = comparison_30yr_washout['Fixed_Retrospective_Pooled'] - comparison_30yr_washout['Fixed_Enrollment_Pooled']

# Save comparison tables (relative paths: written to the current working directory)
comparison_10yr_washout.to_csv('comparison_all_approaches_10yr_washout.csv')
comparison_30yr_washout.to_csv('comparison_all_approaches_30yr_washout.csv')

print("\n✓ Comparison tables saved:")
print("  - comparison_all_approaches_10yr_washout.csv")
print("  - comparison_all_approaches_30yr_washout.csv")

# Summary statistics
#
# FIX: the "better" counts used `len(table) - 1` as denominator, while the
# numerator counts over every row. The denominator is now the number of diseases
# whose difference is defined (Series.count() ignores NaN), so numerator and
# denominator always refer to the same set of rows.
print("\n" + "="*80)
print("WASHOUT RESULTS SUMMARY - 10 YEAR PREDICTIONS")
print("="*80)

print(f"\nMean AUC across all diseases:")
print(f"  Joint Enrollment:              {comparison_10yr_washout['Joint_Enrollment'].mean():.4f}")
print(f"  Fixed Retrospective (Pooled):  {comparison_10yr_washout['Fixed_Retrospective_Pooled'].mean():.4f}")
print(f"  Fixed Enrollment (Pooled):     {comparison_10yr_washout['Fixed_Enrollment_Pooled'].mean():.4f}")

print(f"\nFixed Retrospective vs Joint:")
retro_better_10yr = (comparison_10yr_washout['Fixed_Retro_vs_Joint'] > 0).sum()
print(f"  Retrospective better: {retro_better_10yr} / {comparison_10yr_washout['Fixed_Retro_vs_Joint'].count()} diseases")
print(f"  Mean difference: {comparison_10yr_washout['Fixed_Retro_vs_Joint'].mean():.4f}")
print(f"  Median difference: {comparison_10yr_washout['Fixed_Retro_vs_Joint'].median():.4f}")

print(f"\nFixed Enrollment vs Joint:")
enroll_better_10yr = (comparison_10yr_washout['Fixed_Enroll_vs_Joint'] > 0).sum()
print(f"  Enrollment better: {enroll_better_10yr} / {comparison_10yr_washout['Fixed_Enroll_vs_Joint'].count()} diseases")
print(f"  Mean difference: {comparison_10yr_washout['Fixed_Enroll_vs_Joint'].mean():.4f}")
print(f"  Median difference: {comparison_10yr_washout['Fixed_Enroll_vs_Joint'].median():.4f}")

print(f"\nFixed Retrospective vs Fixed Enrollment:")
retro_vs_enroll_10yr = (comparison_10yr_washout['Fixed_Retro_vs_Enroll'] > 0).sum()
print(f"  Retrospective better: {retro_vs_enroll_10yr} / {comparison_10yr_washout['Fixed_Retro_vs_Enroll'].count()} diseases")
print(f"  Mean difference: {comparison_10yr_washout['Fixed_Retro_vs_Enroll'].mean():.4f}")
print(f"  Median difference: {comparison_10yr_washout['Fixed_Retro_vs_Enroll'].median():.4f}")

print("\n" + "="*80)
print("WASHOUT RESULTS SUMMARY - 30 YEAR PREDICTIONS")
print("="*80)

print(f"\nMean AUC across all diseases:")
print(f"  Joint Enrollment:              {comparison_30yr_washout['Joint_Enrollment'].mean():.4f}")
print(f"  Fixed Retrospective (Pooled):  {comparison_30yr_washout['Fixed_Retrospective_Pooled'].mean():.4f}")
print(f"  Fixed Enrollment (Pooled):     {comparison_30yr_washout['Fixed_Enrollment_Pooled'].mean():.4f}")

print(f"\nFixed Retrospective vs Joint:")
retro_better_30yr = (comparison_30yr_washout['Fixed_Retro_vs_Joint'] > 0).sum()
print(f"  Retrospective better: {retro_better_30yr} / {comparison_30yr_washout['Fixed_Retro_vs_Joint'].count()} diseases")
print(f"  Mean difference: {comparison_30yr_washout['Fixed_Retro_vs_Joint'].mean():.4f}")
print(f"  Median difference: {comparison_30yr_washout['Fixed_Retro_vs_Joint'].median():.4f}")

print(f"\nFixed Enrollment vs Joint:")
enroll_better_30yr = (comparison_30yr_washout['Fixed_Enroll_vs_Joint'] > 0).sum()
print(f"  Enrollment better: {enroll_better_30yr} / {comparison_30yr_washout['Fixed_Enroll_vs_Joint'].count()} diseases")
print(f"  Mean difference: {comparison_30yr_washout['Fixed_Enroll_vs_Joint'].mean():.4f}")
print(f"  Median difference: {comparison_30yr_washout['Fixed_Enroll_vs_Joint'].median():.4f}")

print(f"\nFixed Retrospective vs Fixed Enrollment:")
retro_vs_enroll_30yr = (comparison_30yr_washout['Fixed_Retro_vs_Enroll'] > 0).sum()
print(f"  Retrospective better: {retro_vs_enroll_30yr} / {comparison_30yr_washout['Fixed_Retro_vs_Enroll'].count()} diseases")
print(f"  Mean difference: {comparison_30yr_washout['Fixed_Retro_vs_Enroll'].mean():.4f}")
print(f"  Median difference: {comparison_30yr_washout['Fixed_Retro_vs_Enroll'].median():.4f}")

# Top performing diseases: the 10 rows with the highest Fixed_Retrospective_Pooled AUC
print("\n" + "="*80)
print("TOP 10 DISEASES - 10 YEAR WASHOUT (Fixed Retrospective Pooled)")
print("="*80)
top_10_10yr = comparison_10yr_washout.nlargest(10, 'Fixed_Retrospective_Pooled')
print(top_10_10yr[['Fixed_Retrospective_Pooled', 'Fixed_Enrollment_Pooled', 'Joint_Enrollment', 
                   'Fixed_Retro_vs_Enroll']].round(4))

print("\n" + "="*80)
print("TOP 10 DISEASES - 30 YEAR WASHOUT (Fixed Retrospective Pooled)")
print("="*80)
top_10_30yr = comparison_30yr_washout.nlargest(10, 'Fixed_Retrospective_Pooled')
print(top_10_30yr[['Fixed_Retrospective_Pooled', 'Fixed_Enrollment_Pooled', 'Joint_Enrollment', 
                   'Fixed_Retro_vs_Enroll']].round(4))

print("\n✓ Washout analysis complete!")

In [None]:
import pandas as pd
import numpy as np

def _washout_effect_table(no_washout_df, washout_df, diseases):
    """Build the per-disease washout-effect table for one prediction horizon.

    Reads the 'Fixed_Retrospective_Pooled' column of both frames for the given
    diseases and returns a frame with columns No_Washout, With_Washout,
    Difference (With_Washout - No_Washout) and Percent_Loss
    (Difference / No_Washout * 100).
    """
    effect = pd.DataFrame({
        'No_Washout': no_washout_df.loc[diseases, 'Fixed_Retrospective_Pooled'],
        'With_Washout': washout_df.loc[diseases, 'Fixed_Retrospective_Pooled'],
    }, index=diseases)
    effect['Difference'] = effect['With_Washout'] - effect['No_Washout']
    effect['Percent_Loss'] = (effect['Difference'] / effect['No_Washout']) * 100
    return effect


def _report_washout_effect(effect):
    """Print means, loss-threshold counts and the five most negative differences."""
    delta = effect['Difference']
    n_diseases = len(effect)

    print("\nMean AUC:")
    print(f"  No Washout:    {effect['No_Washout'].mean():.4f}")
    print(f"  With Washout: {effect['With_Washout'].mean():.4f}")
    print(f"  Mean difference: {delta.mean():.4f}")
    print(f"  Mean % loss: {effect['Percent_Loss'].mean():.2f}%")

    # A disease counts as "< X loss" when its difference is greater than -X.
    print("\nSummary:")
    print(f"  Diseases with <0.01 AUC loss: {(delta > -0.01).sum()} / {n_diseases}")
    print(f"  Diseases with <0.02 AUC loss: {(delta > -0.02).sum()} / {n_diseases}")
    print(f"  Diseases with <0.05 AUC loss: {(delta > -0.05).sum()} / {n_diseases}")
    print(f"  Mean absolute difference: {delta.abs().mean():.4f}")
    print(f"  Median absolute difference: {delta.abs().median():.4f}")

    print("\nTop 5 largest losses:")
    print(effect.nsmallest(5, 'Difference')[['No_Washout', 'With_Washout', 'Difference', 'Percent_Loss']].round(4))


# Read the four comparison tables (no washout / 1-year washout, 10yr / 30yr),
# using the first CSV column as the index.
df_10yr = pd.read_csv('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/comparison_all_approaches_10yr.csv', index_col=0)
df_30yr = pd.read_csv('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/comparison_all_approaches_30yr.csv', index_col=0)
df_10yr_washout = pd.read_csv('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/comparison_all_approaches_10yr_washout.csv', index_col=0)
df_30yr_washout = pd.read_csv('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/comparison_all_approaches_30yr_washout.csv', index_col=0)

_section_rule = "=" * 80

# Overall heading for the Fixed_Retrospective_Pooled washout comparison.
print(_section_rule)
print("WASHOUT EFFECT: Comparing Fixed_Retrospective_Pooled (No Washout vs 1-Year Washout)")
print(_section_rule)

# Diseases present in both the no-washout and washout table, in sorted order.
common_diseases_10yr = sorted(set(df_10yr.index) & set(df_10yr_washout.index))
common_diseases_30yr = sorted(set(df_30yr.index) & set(df_30yr_washout.index))

# ----- 10-year horizon -----
print("\n" + _section_rule)
print("10-YEAR PREDICTIONS: No Washout vs 1-Year Washout")
print(_section_rule)

washout_effect_10yr = _washout_effect_table(df_10yr, df_10yr_washout, common_diseases_10yr)
_report_washout_effect(washout_effect_10yr)

# ----- 30-year horizon -----
print("\n" + _section_rule)
print("30-YEAR PREDICTIONS: No Washout vs 1-Year Washout")
print(_section_rule)

washout_effect_30yr = _washout_effect_table(df_30yr, df_30yr_washout, common_diseases_30yr)
_report_washout_effect(washout_effect_30yr)

# Persist both tables (relative paths: written to the current working directory).
washout_effect_10yr.to_csv('washout_effect_10yr.csv')
washout_effect_30yr.to_csv('washout_effect_30yr.csv')

print("\n" + _section_rule)
print("CONCLUSION:")
print(_section_rule)
print("If mean/median differences are <0.01-0.02, washout has minimal impact!")
print("This suggests the model is robust and not overfitting to prevalent cases.")

In [None]:
import pandas as pd
import numpy as np

# Load your results
# Both files are read with their first column as the index; later code in
# this cell looks rows up by disease name through that index.
df_10yr, df_30yr = (
    pd.read_csv(results_path, index_col=0)
    for results_path in (
        '/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/comparison_all_approaches_10yr.csv',
        '/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/comparison_all_approaches_30yr.csv',
    )
)

# Published benchmarks (from literature search)
# Each row: (disease, 10yr (min, max, typical), 30yr (min, max, typical), reference)
_BENCHMARK_ROWS = [
    (
        'ASCVD',
        (0.65, 0.80, 0.72),  # 10yr: typical CVD prediction models
        (0.65, 0.80, 0.72),  # 30yr: long-term CVD risk
        'CVD risk models (Framingham, PCE, etc.)',
    ),
    (
        'Diabetes',
        (0.60, 0.75, 0.67),  # 10yr: type 2 diabetes prediction
        (0.64, 0.80, 0.70),  # 30yr: long-term diabetes complications
        'Diabetes prediction models',
    ),
    (
        'Heart_Failure',
        (0.70, 0.85, 0.78),  # 10yr: heart failure prediction
        (0.82, 0.85, 0.83),  # 30yr: long-term HF (published: 0.82-0.85)
        'Heart failure long-term prediction (C-statistic 0.82-0.85)',
    ),
    (
        'Bladder_Cancer',
        (0.65, 0.88, 0.76),  # 10yr: bladder cancer (published: C-index 0.876)
        (0.60, 0.88, 0.74),  # 30yr: long-term bladder cancer
        'Bladder cancer mortality (C-index 0.876)',
    ),
    (
        'CKD',
        (0.67, 0.71, 0.69),  # 10yr: CKD prediction (published: 0.667-0.694)
        (0.58, 0.70, 0.64),  # 30yr: long-term CKD
        'CKD prediction models (C-statistic 0.667-0.694)',
    ),
    (
        'Prostate_Cancer',
        (0.65, 0.90, 0.78),  # 10yr: prostate cancer (published: C-index 0.869, AUC 0.847-0.904)
        (0.64, 0.90, 0.77),  # 30yr: long-term prostate cancer
        'Prostate cancer nomogram (C-index 0.869, AUC 0.847-0.904)',
    ),
    (
        'Stroke',
        (0.65, 0.75, 0.70),  # 10yr: stroke prediction
        (0.58, 0.75, 0.66),  # 30yr: long-term stroke
        'Stroke risk prediction models',
    ),
    (
        'Lung_Cancer',
        (0.60, 0.75, 0.68),  # 10yr: lung cancer prediction
        (0.63, 0.75, 0.69),  # 30yr: long-term lung cancer
        'Lung cancer prediction models',
    ),
    (
        'Colorectal_Cancer',
        (0.60, 0.75, 0.68),  # 10yr: colorectal cancer
        (0.58, 0.75, 0.67),  # 30yr: long-term colorectal cancer
        'Colorectal cancer prediction models',
    ),
    (
        'Parkinsons',
        (0.70, 0.80, 0.75),  # 10yr: Parkinson's prediction
        (0.66, 0.80, 0.73),  # 30yr: long-term Parkinson's
        'Parkinson\'s disease prediction models',
    ),
]

# Format: {disease: {'10yr': (min, max, typical), '30yr': (min, max, typical), 'reference': 'study type'}}
published_benchmarks = {
    disease_name: {'10yr': range_10yr, '30yr': range_30yr, 'reference': source_note}
    for disease_name, range_10yr, range_30yr, source_note in _BENCHMARK_ROWS
}

# Create comparison table
# One record per benchmark disease that has results at both horizons;
# diseases missing from either results table are skipped without a message.
model_column = 'Fixed_Retrospective_Pooled'
comparison_data = []

for disease, benchmark in published_benchmarks.items():
    has_both_horizons = disease in df_10yr.index and disease in df_30yr.index
    if not has_both_horizons:
        continue

    auc_10yr = df_10yr.loc[disease, model_column]
    auc_30yr = df_30yr.loc[disease, model_column]

    # Each benchmark entry is a (min, max, typical) triple.
    low_10yr, high_10yr, typical_10yr = benchmark['10yr']
    low_30yr, high_30yr, typical_30yr = benchmark['30yr']

    record = {
        'Disease': disease,
        'Your_10yr_AUC': auc_10yr,
        'Published_10yr_Range': f"{low_10yr:.2f}-{high_10yr:.2f}",
        'Published_10yr_Typical': typical_10yr,
        # Positive means our AUC is above the published typical value.
        '10yr_vs_Published': auc_10yr - typical_10yr,
        'Your_30yr_AUC': auc_30yr,
        'Published_30yr_Range': f"{low_30yr:.2f}-{high_30yr:.2f}",
        'Published_30yr_Typical': typical_30yr,
        '30yr_vs_Published': auc_30yr - typical_30yr,
        'Reference': benchmark['reference'],
    }
    comparison_data.append(record)

comparison_df = pd.DataFrame(comparison_data)

# Print the per-disease comparison, one table per prediction horizon.
table_banner = "=" * 120
table_rule = "-" * 120

columns_10yr = ['Disease', 'Your_10yr_AUC', 'Published_10yr_Range', 'Published_10yr_Typical',
                '10yr_vs_Published', 'Reference']
columns_30yr = ['Disease', 'Your_30yr_AUC', 'Published_30yr_Range', 'Published_30yr_Typical',
                '30yr_vs_Published', 'Reference']

print(table_banner)
print("COMPARISON: Your Fixed_Retrospective_Pooled Results vs Published Benchmarks")
print(table_banner)

horizon_tables = (
    ("\n10-YEAR PREDICTIONS:", columns_10yr),
    ("\n\n30-YEAR PREDICTIONS:", columns_30yr),
)
for table_heading, table_columns in horizon_tables:
    print(table_heading)
    print(table_rule)
    # to_string(index=False) prints every row without the positional index.
    print(comparison_df[table_columns].to_string(index=False))

# Summary statistics
def _print_horizon_summary(title, horizon, tolerance=0.05):
    """Print summary statistics comparing our AUCs with the published benchmarks.

    Reads ``comparison_df`` and ``published_benchmarks`` from the notebook
    namespace.

    Parameters
    ----------
    title : str
        Heading printed above the statistics, e.g. "10-Year Predictions:".
    horizon : str
        '10yr' or '30yr'; selects both the ``comparison_df`` columns and the
        matching (min, max, typical) triple in ``published_benchmarks``.
    tolerance : float, default 0.05
        Half-width of the band around the published typical value used for
        the "within +/- tolerance of typical" count.
    """
    auc = comparison_df[f'Your_{horizon}_AUC']
    typical = comparison_df[f'Published_{horizon}_Typical']
    delta = comparison_df[f'{horizon}_vs_Published']
    n_total = len(comparison_df)

    # Actual published interval bounds. Previously the "within published range"
    # line tested typical +/- 0.05 instead, which disagrees with the
    # Published_*_Range column whenever the range is wider or narrower than that.
    lower = comparison_df['Disease'].map(lambda name: published_benchmarks[name][horizon][0])
    upper = comparison_df['Disease'].map(lambda name: published_benchmarks[name][horizon][1])
    in_published_range = (auc >= lower) & (auc <= upper)

    # The former statistic, kept under a label that says what it measures.
    near_typical = (auc >= typical - tolerance) & (auc <= typical + tolerance)

    print(f"\n{title}")
    print(f"  Mean difference (Your - Published): {delta.mean():.4f}")
    print(f"  Diseases where you exceed published: {(delta > 0).sum()} / {n_total}")
    print(f"  Diseases within published range: {in_published_range.sum()} / {n_total}")
    print(f"  Diseases within +/-{tolerance:.2f} of published typical: {near_typical.sum()} / {n_total}")


print("\n\n" + "="*120)
print("SUMMARY STATISTICS")
print("="*120)
_print_horizon_summary("10-Year Predictions:", '10yr')
_print_horizon_summary("30-Year Predictions:", '30yr')

# Highlight exceptional performance
# "Exceptional" means an AUC more than 0.05 above the published typical value.
exceptional_banner = "=" * 120
print("\n\n" + exceptional_banner)
print("EXCEPTIONAL PERFORMANCE (>0.05 above published typical)")
print(exceptional_banner)

exceptional_10yr = comparison_df.loc[comparison_df['10yr_vs_Published'] > 0.05]
exceptional_30yr = comparison_df.loc[comparison_df['30yr_vs_Published'] > 0.05]

exceptional_reports = (
    ("\n10-Year:", "\n10-Year: None", exceptional_10yr,
     ['Disease', 'Your_10yr_AUC', 'Published_10yr_Typical', '10yr_vs_Published']),
    ("\n30-Year:", "\n30-Year: None", exceptional_30yr,
     ['Disease', 'Your_30yr_AUC', 'Published_30yr_Typical', '30yr_vs_Published']),
)
for report_heading, none_message, exceptional_rows, shown_columns in exceptional_reports:
    # With no qualifying disease, print the "None" line instead of an empty table.
    if len(exceptional_rows) == 0:
        print(none_message)
        continue
    print(report_heading)
    print(exceptional_rows[shown_columns].to_string(index=False))