# R3: Competing Risks - Patients Can Develop Multiple Diseases

## Reviewer Question

**Referee #3**: "It's unclear to me whether and how well the model actually accounts for competing risks, such as death, emigration, and other strong competitors. This can also be caused by diagnostic hierarchy. What scares me are the reported hazards (e.g. figures S6-8), which seem to decrease for very old individuals, which can be interpreted as decreased risks. This looks like a competing risk issue."

## Our Response

The reviewer's primary concern, decreasing hazards at older ages, is addressed in **R3_Q4_Decreasing_Hazards_Censoring_Bias.ipynb**, which attributes the pattern to censoring bias rather than competing risks.

Regarding competing risks, traditional competing-risk models treat events as **mutually exclusive**: once a patient experiences one event, follow-up for every other event stops. In reality, patients often develop **multiple serious diseases over time**.

**Aladynoulli's multi-disease approach handles this directly** by modeling all 348 diseases simultaneously. Patients remain at risk for every other disease after developing one.

This notebook demonstrates this by finding patients who develop multiple serious outcomes (e.g., Myocardial Infarction, Lung Cancer, Colorectal Cancer, Breast Cancer) and showing that, at the time of their first diagnosis, their predicted risks exceed the corrected population prevalence baseline for the diseases they go on to develop.


## Key Point: Multiple Diseases Are Not Mutually Exclusive

Unlike traditional competing risk models, Aladynoulli:
- **Models all 348 diseases simultaneously** - patients can develop multiple diseases
- **No censoring after first disease** - patients remain at risk for all diseases
- **Handles real-world complexity** - patients often develop multiple serious conditions

The examples below show patients who develop multiple serious outcomes (MI, Lung Cancer, Colorectal Cancer, Breast Cancer) and compare their predicted risk for each of those diseases with the corrected population prevalence baseline.
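
To make the contrast concrete, here is a minimal sketch on a made-up event-time matrix. It assumes the convention used in this notebook's code (an entry below 51 is the timepoint of diagnosis, and a patient is at risk for a disease at timepoint `t` when their entry is `>= t`). The sentinel `NO_EVENT = 51` and the toy matrix are illustrative assumptions, not cohort data.

```python
import numpy as np

NO_EVENT = 51  # assumed sentinel: no diagnosis within the observation window

# Toy event-time matrix (hypothetical): rows = patients, columns = diseases
E = np.array([
    [10, 25, NO_EVENT],        # disease 0 at t=10, disease 1 at t=25
    [NO_EVENT, NO_EVENT, 40],  # only disease 2, at t=40
    [5, NO_EVENT, 30],         # disease 0 at t=5, disease 2 at t=30
])
t = 20

# Multi-disease view: a patient is at risk for disease d at time t
# as long as disease d itself has not yet occurred.
at_risk_multi = E >= t

# First-event view: after the first diagnosis of any kind,
# the patient leaves the risk set for every disease.
first_event = E.min(axis=1, keepdims=True)
at_risk_first = np.broadcast_to(first_event >= t, E.shape)

print("multi-disease risk set:\n", at_risk_multi.astype(int))
print("first-event risk set:\n", at_risk_first.astype(int))
```

In this toy example, patient 0 drops out of every risk set under the first-event view even though their second diagnosis at t=25 is still to come; under the multi-disease view they remain at risk for diseases 1 and 2.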


## What This Notebook Shows

1. **Load predictions from joint model** - Using pi predictions from the full-mode model trained on all data
2. **Find patients with multiple serious outcomes** - MI, Lung Cancer, Colorectal Cancer, Breast Cancer
3. **Compare to corrected prevalence baseline** - Show that patient risks are elevated over population baseline
4. **Visualize risk trajectories** - Plot how predicted risks compare to baseline for each disease


# ============================================================================
# Load Data: Predictions from Joint Model
# ============================================================================

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Load pi predictions from joint model (full-mode, all data)
pi_predictions = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/censor_e_batchrun_vectorized/pi_fullmode_400k.pt', 
                           map_location='cpu', weights_only=False)
print(f"✓ Loaded pi predictions: {pi_predictions.shape}")

# Load E matrix (corrected) to find actual disease events
E_batch = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_matrix_corrected.pt',
                     map_location='cpu', weights_only=False)
print(f"✓ Loaded E matrix: {E_batch.shape}")

# Load corrected prevalence for baseline
prevalence_t = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/prevalence_t_corrected.pt',
                          map_location='cpu', weights_only=False)
print(f"✓ Loaded corrected prevalence: {prevalence_t.shape}")

# Load disease names
disease_names = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/disease_names.csv').iloc[:, 1].tolist()
print(f"✓ Loaded {len(disease_names)} disease names")
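
The patient search further down expects a `disease_indices` mapping from each target disease to its column index, which this notebook never defines. One way to build it is a case-insensitive substring search over `disease_names`. The sketch below is only an illustration: the helper `find_disease_index`, the example names, and the search terms are all hypothetical and would need to be checked against the real name list.

```python
def find_disease_index(disease_names, include, exclude=()):
    """Return the index of the first name containing every `include` term
    and none of the `exclude` terms (case-insensitive), or None if absent."""
    for i, name in enumerate(disease_names):
        lowered = str(name).lower()
        if all(term in lowered for term in include) and not any(
            term in lowered for term in exclude
        ):
            return i
    return None

# Hypothetical names and search terms, for illustration only
example_names = [
    "Myocardial infarction",
    "Cancer of bronchus; lung",
    "Colorectal cancer",
    "Breast cancer [female]",
    "Skin cancer",
]
targets = {
    "Myocardial Infarction": ["myocardial infarction"],
    "Lung Cancer": ["cancer", "lung"],
    "Colorectal Cancer": ["colorectal"],
    "Breast Cancer": ["breast", "cancer"],
}

disease_indices = {}
for label, terms in targets.items():
    idx = find_disease_index(example_names, terms)
    if idx is None:
        print(f"Warning: no match for {label}")
    else:
        disease_indices[label] = idx

print(disease_indices)
```

Printing the matched names before using the indices is a cheap guard against a search term silently picking up the wrong phenotype.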


In [None]:
%run /Users/sarahurbut/aladynoulli2/pyScripts/dec_6_revision/new_notebooks/pythonscripts/compute_pi_from_fullmode_models.py --output_dir /Users/sarahurbut/Library/CloudStorage/Dropbox/censor_e_batchrun_vectorized/


Processing batch 0: 0 to 10000
Loading model from: /Users/sarahurbut/Library/CloudStorage/Dropbox/censor_e_batchrun_vectorized/enrollment_model_W0.0001_batch_0_10000.pt
  Lambda shape: torch.Size([10000, 21, 52])
  Phi shape: torch.Size([21, 348, 52])
  Kappa: 2.88468861579895

Computing pi predictions...
  Pi shape: torch.Size([10000, 348, 52])
  Pi range: [0.000000, 0.400859]
✓ Saved batch pi to: /Users/sarahurbut/Library/CloudStorage/Dropbox/censor_e_batchrun_vectorized/pi_fullmode_batch_0_10000.pt

Processing batch 1: 10000 to 20000
Loading model from: /Users/sarahurbut/Library/CloudStorage/Dropbox/censor_e_batchrun_vectorized/enrollment_model_W0.0001_batch_10000_20000.pt
  Lambda shape: torch.Size([10000, 21, 52])
  Phi shape: torch.Size([21, 348, 52])
  Kappa: 2.906322717666626

Computing pi predictions...
  Pi shape: torch.Size([10000, 348, 52])
  Pi range: [0.000000, 0.286479]
✓ Saved batch pi to: /Users/sarahurbut/Library/CloudStorage/Dropbox/censor_e_batchrun_vectorized/pi_f
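
The shapes in this log (Lambda `[N, 21, 52]`, Phi `[21, 348, 52]`, Pi `[N, 348, 52]`) match the formula quoted in the code comments of this notebook, `pi = kappa * einsum('nkt,kdt->ndt', softmax(lambda), sigmoid(phi))`. The NumPy sketch below illustrates that formula on random toy inputs. It is not the project's `calculate_pi_pred`; taking the softmax over the signature axis (axis 1) is an assumption, and the names `theta` and `eta` are introduced here only for readability.

```python
import numpy as np

def softmax(x, axis):
    """Numerically stable softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def compute_pi(lam, phi, kappa):
    """pi[n, d, t] = kappa * sum_k softmax(lam)[n, k, t] * sigmoid(phi)[k, d, t]."""
    theta = softmax(lam, axis=1)   # [N, K, T], assumed: weights over K sum to 1
    eta = sigmoid(phi)             # [K, D, T], per-signature disease probabilities
    return kappa * np.einsum('nkt,kdt->ndt', theta, eta)

# Toy dimensions (hypothetical), far smaller than the real 21 x 348 x 52
rng = np.random.default_rng(0)
N, K, D, T = 4, 3, 5, 6
lam = rng.normal(size=(N, K, T))
phi = rng.normal(size=(K, D, T)) - 4.0   # shifted so probabilities are small

pi_toy = compute_pi(lam, phi, kappa=1.0)
print(pi_toy.shape)
```

With `kappa = 1`, each `pi` value is a weighted average of sigmoid outputs and therefore stays strictly between 0 and 1; a `kappa` above 1, as in the log, rescales that average.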

In [1]:
# ============================================================================
# Find Patients with Multiple Serious Outcomes
# ============================================================================

print(f"\n{'='*80}")
print(f"FINDING PATIENTS WITH MULTIPLE SERIOUS OUTCOMES")
print(f"{'='*80}")

patients_with_multiple = []

for patient_idx in range(len(E_batch)):
    event_times = E_batch[patient_idx]
    
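    # Check which diseases this patient has.
    # disease_indices maps each target disease name to its column index in E_batch;
    # it is not created in this notebook and must be defined before this cell runs.
    patient_diseases = {}
    for disease_name, d_idx in disease_indices.items():
        if event_times[d_idx] < 51:  # Event timepoint < 51, i.e. diagnosed before age 81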
            patient_diseases[disease_name] = {
                'index': d_idx,
                'age': event_times[d_idx].item() + 30  # Convert to actual age
            }
    
    # Keep patients with at least 2 of the target diseases
    if len(patient_diseases) >= 2:
        # Evaluate at earliest diagnosis timepoint
        earliest_age = min([d['age'] for d in patient_diseases.values()])
        eval_timepoint = int(earliest_age - 30)  # Convert back to timepoint
        
        # Get patient predictions at this timepoint
        pi_at_time = pi_predictions[patient_idx, :, eval_timepoint]
        
        # Calculate risk ratios for each disease
        risk_ratios = {}
        for disease_name, disease_info in patient_diseases.items():
            d_idx = disease_info['index']
            pred_risk = pi_at_time[d_idx].item()
            baseline_risk = prevalence_t[d_idx, eval_timepoint].item()
            rr = pred_risk / baseline_risk if baseline_risk > 0 else 0
            risk_ratios[disease_name] = rr
        
        # Keep if at least 2 diseases have elevated risk (RR > 1.0)
        elevated_count = sum(1 for rr in risk_ratios.values() if rr > 1.0)
        if elevated_count >= 2:
            patients_with_multiple.append({
                'patient_idx': patient_idx,
                'diseases': patient_diseases,
                'risk_ratios': risk_ratios,
                'eval_timepoint': eval_timepoint,
                'eval_age': earliest_age,
                'n_diseases': len(patient_diseases),
                'n_elevated': elevated_count
            })

print(f"✓ Found {len(patients_with_multiple)} patients with 2+ diseases and elevated risk")

# Sort by number of diseases and combined risk ratio
patients_with_multiple.sort(key=lambda x: (x['n_diseases'], sum(x['risk_ratios'].values())), reverse=True)

# Show top examples
print(f"\n{'='*80}")
print(f"TOP 10 EXAMPLES")
print(f"{'='*80}")
print(f"{'Patient':<10} {'Diseases':<50} {'Ages':<30} {'Risk Ratios':<40}")
print("-"*130)

for i, patient in enumerate(patients_with_multiple[:10]):
    diseases_str = ', '.join([d for d in patient['diseases'].keys()])
    ages_str = ', '.join([f"{int(d['age'])}" for d in patient['diseases'].values()])
    rrs_str = ', '.join([f"{d}: {patient['risk_ratios'][d]:.2f}x" for d in patient['diseases'].keys()])
    print(f"{patient['patient_idx']:<10} {diseases_str[:48]:<50} {ages_str[:28]:<30} {rrs_str[:38]:<40}")

# Select top 3 examples for plotting
selected_patients = patients_with_multiple[:3]
print(f"\n✓ Selected {len(selected_patients)} patients for detailed visualization")
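
The loop above visits every patient in Python, which can be slow for a cohort of this size. The first stage (who has at least two target diseases, and when was the earliest one) can also be done with array operations. The sketch below assumes the same convention as the loop (an entry below 51 means the disease occurred at that timepoint, anything else means no event); the function name and toy matrix are hypothetical.

```python
import numpy as np

def find_multi_disease_patients(E, target_cols, max_t=51, min_diseases=2):
    """Return (patient_indices, n_diseases, earliest_timepoint) for patients
    with at least `min_diseases` of the target diseases before `max_t`."""
    sub = E[:, target_cols]                    # [N, n_targets]
    has_event = sub < max_t
    n_diseases = has_event.sum(axis=1)
    keep = np.flatnonzero(n_diseases >= min_diseases)
    # Non-events are >= max_t, so the row minimum is the earliest diagnosis
    earliest = sub[keep].min(axis=1)
    return keep, n_diseases[keep], earliest

# Toy event-time matrix (hypothetical): 4 patients, 4 diseases
E_toy = np.array([
    [10, 25, 51, 51],
    [51, 51, 40, 51],
    [5, 51, 30, 12],
    [51, 51, 51, 51],
])
keep, n_dis, earliest = find_multi_disease_patients(E_toy, target_cols=[0, 1, 2])
print(keep, n_dis, earliest)
```

For a torch tensor, `E_batch.numpy()` would give the array this sketch expects; the risk-ratio filter can then run only on the patients in `keep`.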




In [None]:
# ============================================================================
# Plot Risk Trajectories: Patient Risk vs Population Baseline
# ============================================================================

print(f"\n{'='*80}")
print(f"CREATING RISK TRAJECTORY PLOTS")
print(f"{'='*80}")

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (16, 12)
plt.rcParams['font.size'] = 10

# Create subplots: one row per patient, one column per disease they have
n_patients = len(selected_patients)
max_diseases = max(p['n_diseases'] for p in selected_patients)

# squeeze=False always returns a 2D array of axes, even for a single row or column
fig, axes = plt.subplots(n_patients, max_diseases, figsize=(5*max_diseases, 4*n_patients),
                         squeeze=False)

for patient_row, patient in enumerate(selected_patients):
    patient_idx = patient['patient_idx']
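    diseases = list(patient['diseases'].keys())
    
    # Hide unused panels for patients with fewer diseases than the widest row
    for unused_col in range(len(diseases), max_diseases):
        axes[patient_row, unused_col].axis('off')
    
    for disease_col, disease_name in enumerate(diseases):
        ax = axes[patient_row, disease_col]
        d_idx = patient['diseases'][disease_name]['index']
        dx_age = patient['diseases'][disease_name]['age']
        
        # Get risk trajectories over the plotted window (timepoints 0-50 = ages 30-80).
        # pi_predictions has 52 timepoints, so slice it to the length of the age axis.
        ages = np.arange(30, 81)  # Ages 30-80
        n_plot = len(ages)
        
        patient_risks = pi_predictions[patient_idx, d_idx, :n_plot].numpy()
        baseline_risks = prevalence_t[d_idx, :n_plot].numpy()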
        
        # Plot
        ax.plot(ages, patient_risks, 'r-', linewidth=2.5, label='Patient Risk', alpha=0.8)
        ax.plot(ages, baseline_risks, 'b--', linewidth=2, label='Population Baseline', alpha=0.7)
        
        # Add vertical line at diagnosis
        ax.axvline(x=dx_age, color='purple', linestyle=':', linewidth=2, 
                   label=f'Diagnosis (Age {int(dx_age)})', alpha=0.8)
        
        # Shade regions
        ax.axvspan(30, dx_age, alpha=0.1, color='gray', label='Before Diagnosis')
        ax.axvspan(dx_age, 80, alpha=0.1, color='lightcoral', label='After Diagnosis')
        
        # Add risk ratio annotation
        eval_tp = patient['eval_timepoint']
        rr = patient['risk_ratios'][disease_name]
        ax.annotate(f'RR at age {int(patient["eval_age"])}: {rr:.2f}x',
                   xy=(patient['eval_age'], patient_risks[eval_tp]),
                   xytext=(10, 10), textcoords='offset points',
                   bbox=dict(boxstyle='round,pad=0.5', facecolor='yellow', alpha=0.7),
                   fontsize=9, fontweight='bold')
        
        # Labels
        if patient_row == 0:
            ax.set_title(f'{disease_name}\nPatient {patient_idx}', fontsize=12, fontweight='bold')
        else:
            ax.set_title(f'{disease_name}', fontsize=11)
        
        if disease_col == 0:
            ax.set_ylabel('Disease Risk', fontsize=11, fontweight='bold')
        
        if patient_row == n_patients - 1:
            ax.set_xlabel('Age', fontsize=11, fontweight='bold')
        
        ax.legend(loc='upper left', fontsize=9)
        ax.grid(alpha=0.3)
        ax.set_ylim(bottom=0)

plt.suptitle('Risk Trajectories: Patients with Multiple Serious Outcomes\n'
             'Patient Risk vs Corrected Population Prevalence Baseline',
             fontsize=14, fontweight='bold', y=0.995)
plt.tight_layout()
plt.show()

print(f"\n✓ Created plots for {len(selected_patients)} patients")
print(f"  Each patient has 2+ serious diseases with elevated risk over baseline")
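
Because the saved `pi` tensors have 52 timepoints (see the shapes logged earlier) while the plotted window covers ages 30-80 (51 points), it can help to derive the age axis from the trajectory length instead of hard-coding it. The helpers below are a small sketch of the age/timepoint convention used throughout this notebook (age = timepoint + 30); the function and constant names are hypothetical.

```python
BASE_AGE = 30  # age at timepoint 0, as used throughout this notebook

def timepoint_to_age(t):
    """Convert a model timepoint to age in years."""
    return t + BASE_AGE

def age_to_timepoint(age, n_timepoints):
    """Convert an age to a timepoint index, rejecting out-of-range values."""
    t = int(age) - BASE_AGE
    if not 0 <= t < n_timepoints:
        raise ValueError(f"age {age} is outside the modeled range")
    return t

def age_axis(n_timepoints):
    """Ages matching a trajectory with `n_timepoints` entries."""
    return [timepoint_to_age(t) for t in range(n_timepoints)]

axis_52 = age_axis(52)
print(axis_52[0], axis_52[-1], len(axis_52))
```

Plotting `age_axis(len(trajectory))` against the trajectory keeps the x and y arrays the same length regardless of how many timepoints the tensor carries.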


In [None]:
# ============================================================================
# SIMPLIFIED: Find patients with BOTH heart disease AND cancer
# Check their predictions are elevated over corrected prevalence baseline
# ============================================================================
pi_predictions = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/censor_e_batchrun_vectorized/pi_fullmode_400k.pt',
                            map_location='cpu', weights_only=False)
E_batch = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_matrix_corrected.pt',
                     map_location='cpu', weights_only=False)
prevalence_t = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/prevalence_t_corrected.pt',
                          map_location='cpu', weights_only=False)
disease_names = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/disease_names.csv').iloc[:, 1].tolist()
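
Since these four inputs are indexed against each other (`pi[patient, disease, timepoint]`, `E[patient, disease]`, `prevalence[disease, timepoint]`, and one name per disease), a quick consistency check after loading can catch misaligned files early. The sketch below is a hypothetical helper working on plain shape tuples; the example shapes are illustrative, with only 348 diseases and 52 timepoints taken from the log in this notebook.

```python
def check_shapes(pi_shape, E_shape, prevalence_shape, n_names):
    """Return a list of human-readable mismatches (empty if all consistent)."""
    n, d, t = pi_shape
    problems = []
    if E_shape[0] != n:
        problems.append(f"E has {E_shape[0]} patients, pi has {n}")
    if E_shape[1] != d:
        problems.append(f"E has {E_shape[1]} diseases, pi has {d}")
    if prevalence_shape[0] != d:
        problems.append(f"prevalence has {prevalence_shape[0]} diseases, pi has {d}")
    if prevalence_shape[1] != t:
        problems.append(f"prevalence has {prevalence_shape[1]} timepoints, pi has {t}")
    if n_names != d:
        problems.append(f"{n_names} disease names for {d} diseases")
    return problems

# Illustrative shapes: patient count is made up
ok = check_shapes((1000, 348, 52), (1000, 348), (348, 52), 348)
bad = check_shapes((1000, 348, 52), (1000, 348), (348, 52), 349)
print(ok)
print(bad)
```

With the loaded objects this could be called as `check_shapes(tuple(pi_predictions.shape), tuple(E_batch.shape), tuple(prevalence_t.shape), len(disease_names))`. An off-by-one in the name count is worth watching for, since the CSV may carry a header value in the name column.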


In [None]:

if 'pi_predictions' in locals() and 'E_batch' in locals() and 'disease_names' in locals():
    # Load corrected prevalence for baseline
    prevalence_path = Path("/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/prevalence_t_corrected.pt")
    if prevalence_path.exists():
        prevalence_t = torch.load(str(prevalence_path), map_location='cpu', weights_only=False)
        print(f"✓ Loaded corrected prevalence: {prevalence_t.shape}")
    else:
        print(f"⚠️  Corrected prevalence not found")
        prevalence_t = None
    
    # Find patients with BOTH heart disease AND lung cancer
    print(f"\n{'='*80}")
    print(f"FINDING PATIENTS WITH BOTH HEART DISEASE AND CANCER")
    print(f"{'='*80}")
    
    patients_with_both = []
    
    for patient_idx in range(len(E_batch)):
        event_times = E_batch[patient_idx]
        
        # Check if patient has heart disease; keep the earliest heart disease diagnosis
        has_heart = False
        heart_idx_found = None
        heart_age_found = None
        for hd_idx in heart_disease_indices:
            hd_age = event_times[hd_idx].item()
            if hd_age < 51 and (heart_age_found is None or hd_age < heart_age_found):
                has_heart = True
                heart_idx_found = hd_idx
                heart_age_found = hd_age
        
        # Check if patient has lung cancer
        has_cancer = False
        cancer_age_found = None
        if lung_cancer_idx is not None and event_times[lung_cancer_idx] < 51:
            has_cancer = True
            cancer_age_found = event_times[lung_cancer_idx].item()
        
        if has_heart and has_cancer:
            # Evaluate at earlier diagnosis timepoint
            eval_timepoint = int(min(heart_age_found, cancer_age_found))
            pi_at_time = pi_predictions[patient_idx, :, eval_timepoint]
            
            # Get baseline from corrected prevalence
            if prevalence_t is not None:
                heart_baseline = prevalence_t[heart_idx_found, eval_timepoint].item()
                lung_baseline = prevalence_t[lung_cancer_idx, eval_timepoint].item()
            else:
                heart_baseline = pi_predictions[:, heart_idx_found, eval_timepoint].mean().item()
                lung_baseline = pi_predictions[:, lung_cancer_idx, eval_timepoint].mean().item()
            
            # Patient risks
            heart_pred = pi_at_time[heart_idx_found].item()
            lung_pred = pi_at_time[lung_cancer_idx].item()
            
            # Risk ratios
            heart_rr = heart_pred / heart_baseline if heart_baseline > 0 else 0
            lung_rr = lung_pred / lung_baseline if lung_baseline > 0 else 0
            
            # Keep if both are elevated
            if heart_rr > 1.0 and lung_rr > 1.0:
                patients_with_both.append((
                    patient_idx, heart_idx_found, heart_age_found, 
                    lung_cancer_idx, cancer_age_found,
                    heart_rr, lung_rr, eval_timepoint
                ))
    
    print(f"✓ Found {len(patients_with_both)} patients with BOTH diseases and elevated risk")
    
    if len(patients_with_both) > 0:
        # Sort by combined risk ratio
        patients_with_both.sort(key=lambda x: x[5] * x[6], reverse=True)
        
        # Show top 5 examples
        print(f"\n{'='*80}")
        print(f"TOP 5 EXAMPLES: Patients with BOTH Heart Disease and Cancer")
        print(f"{'='*80}")
        print(f"{'Patient':<10} {'Heart Disease':<30} {'Age':<6} {'RR':<8} {'Lung Cancer':<30} {'Age':<6} {'RR':<8}")
        print("-"*80)
        
        for i, (p_idx, h_idx, h_age, l_idx, l_age, h_rr, l_rr, eval_t) in enumerate(patients_with_both[:5]):
            print(f"{p_idx:<10} {disease_names[h_idx][:28]:<30} {int(h_age+30):<6} {h_rr:<8.2f} "
                  f"{disease_names[l_idx][:28]:<30} {int(l_age+30):<6} {l_rr:<8.2f}")
        
        # Use best example
        patient_idx, heart_idx, heart_age, lung_cancer_idx, lung_cancer_age, heart_rr, lung_rr, eval_timepoint = patients_with_both[0]
        
        print(f"\n{'='*80}")
        print(f"SELECTED EXAMPLE: Patient {patient_idx}")
        print(f"{'='*80}")
        print(f"  Heart Disease: {disease_names[heart_idx]} at age {int(heart_age + 30)}")
        print(f"  Lung Cancer: {disease_names[lung_cancer_idx]} at age {int(lung_cancer_age + 30)}")
        print(f"  Risk Ratios (at age {eval_timepoint + 30}):")
        print(f"    Heart Disease RR: {heart_rr:.2f}x")
        print(f"    Lung Cancer RR: {lung_rr:.2f}x")
        print(f"\n  ✓ This patient has BOTH diseases with elevated risk for both")
        print(f"  ✓ Demonstrates that competing risks are not mutually exclusive")
    else:
        print(f"\n⚠️  No patients found with both diseases and elevated risk")
        patient_idx = None
else:
    print("⚠️  Please run previous cells first to load pi_predictions, E_batch, and disease_names")
    patient_idx = None
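
When a disease group such as heart disease spans several columns of the event matrix, the diagnosis that matters for choosing the evaluation timepoint is the earliest one in the group. A vectorized way to find it for all patients at once is sketched below, assuming the notebook's convention that an entry below 51 is an event timepoint. The function name and toy data are hypothetical.

```python
import numpy as np

def earliest_in_group(E, group_cols, max_t=51):
    """For each patient, find the earliest event among `group_cols`.

    Returns (has_event, earliest_col, earliest_t); earliest_col is -1
    for patients with no event in the group.
    """
    group_cols = np.asarray(group_cols)
    sub = E[:, group_cols]                        # [N, n_group]
    pos = sub.argmin(axis=1)                      # position of the smallest time
    earliest_t = sub[np.arange(len(sub)), pos]
    has_event = earliest_t < max_t
    earliest_col = np.where(has_event, group_cols[pos], -1)
    return has_event, earliest_col, earliest_t

# Toy event-time matrix (hypothetical): columns 0 and 1 form the group
E_toy = np.array([
    [30, 12, 51],
    [51, 51, 51],
    [51, 20, 8],
])
has_event, earliest_col, earliest_t = earliest_in_group(E_toy, [0, 1])
print(has_event, earliest_col, earliest_t)
```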


In [None]:
%run /Users/sarahurbut/aladynoulli2/pyScripts/dec_6_revision/new_notebooks/pythonscripts/compare_pi_calculations.py
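
A baseline that averages predicted risk only over patients who are still at risk for a disease at a given timepoint (event time `>= t`), rather than over the whole cohort, can be sketched as follows. This is a minimal NumPy illustration on toy arrays; the function name and the numbers are hypothetical, and the at-risk rule is the one stated in this notebook's code comments.

```python
import numpy as np

def at_risk_baseline(pi, E, disease_idx, t):
    """Mean predicted risk for `disease_idx` at timepoint `t` among patients
    whose event time for that disease is >= t. Returns NaN if nobody is at risk."""
    at_risk = E[:, disease_idx] >= t
    if not at_risk.any():
        return float('nan')
    return float(pi[at_risk, disease_idx, t].mean())

# Toy inputs (hypothetical): 3 patients, 1 disease, 2 timepoints
pi_toy = np.array([
    [[0.05, 0.1]],
    [[0.10, 0.2]],
    [[0.30, 0.6]],
])
E_toy = np.array([[51], [0], [51]])   # patient 1 was diagnosed at t=0

baseline_at_risk = at_risk_baseline(pi_toy, E_toy, disease_idx=0, t=1)
baseline_everyone = float(pi_toy[:, 0, 1].mean())
print(baseline_at_risk, baseline_everyone)
```

The two baselines differ whenever already-diagnosed patients have systematically different predicted risk, so the choice of denominator changes the resulting risk ratios.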

In [None]:
# ============================================================================
# FIND EXAMPLES FOR MULTIPLE DISEASE PAIRS WITH HIGH RISK RATIOS
# ============================================================================
"""
Find patients with specific disease progressions and high relative risk ratios:
1. ASCVD → Lung Cancer
2. ASCVD → Colon Cancer
3. Breast Cancer → ASCVD

We prioritize patients with risk ratios > 2.5x to show dramatic examples.
"""

import torch
from pathlib import Path
import pandas as pd
import numpy as np

print("="*80)
print("FINDING EXAMPLES: MULTIPLE DISEASE PAIRS WITH HIGH RISK RATIOS")
print("="*80)
print("\nSearching for:")
print("  1. ASCVD → Lung Cancer")
print("  2. ASCVD → Colon Cancer")
print("  3. Breast Cancer → ASCVD")
print("\nPrioritizing patients with risk ratios > 2.5x")

# Load and compute pi_full from model checkpoints (corrected E training)
# Formula: pi = kappa * einsum('nkt,kdt->ndt', softmax(lambda), sigmoid(phi))
import sys
sys.path.insert(0, '/Users/sarahurbut/aladynoulli2/pyScripts')
from utils import calculate_pi_pred

batch_dir = Path('/Users/sarahurbut/Library/CloudStorage/Dropbox/censor_e_batchrun_vectorized/')
pi_full_path = batch_dir / 'pi_fullmode_400k.pt'

# Try to load pre-computed FULL file first
if pi_full_path.exists():
    print(f"\nLoading pre-computed FULL pi from: {pi_full_path}")
    pi_predictions = torch.load(str(pi_full_path), weights_only=False)
    print(f"✓ Loaded full pi predictions: {pi_predictions.shape}")
    use_full_dataset = True
else:
    # Compute pi from model checkpoints
    print(f"\n⚠️  FULL file not found. Computing pi from model checkpoints...")
    print(f"Model directory: {batch_dir}")
    
    # Find all model files and sort by start index
    import re
    def extract_start_idx(filename):
        match = re.search(r'(\d+)_(\d+)\.pt$', filename.name)
        return int(match.group(1)) if match else 0
    
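    # Look for per-batch model checkpoints. The compute_pi_from_fullmode_models.py log shows
    # files named enrollment_model_W0.0001_batch_<start>_<end>.pt in this directory.
    model_files = list(batch_dir.glob('enrollment_model_W0.0001_batch_*_*.pt'))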
    model_files = [f for f in model_files if 'FULL' not in f.name]
    model_files = sorted(model_files, key=extract_start_idx)
    
    if not model_files:
        print(f"⚠️  No model files found in {batch_dir}")
        use_full_dataset = False
    else:
        print(f"Found {len(model_files)} model files")
        pi_batches = []
        
        for i, model_file in enumerate(model_files):
            print(f"  Processing batch {i+1}/{len(model_files)}: {model_file.name}")
            
            # Load model checkpoint
            checkpoint = torch.load(model_file, map_location='cpu', weights_only=False)
            
            # Extract model state dict
            if 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint
            
            # Extract parameters
            lambda_ = state_dict['lambda_'].cpu()  # [N, K, T]
            phi = state_dict['phi'].cpu()  # [K, D, T]
            kappa = state_dict.get('kappa', torch.tensor(1.0))
            if torch.is_tensor(kappa):
                kappa = kappa.cpu()
                if kappa.numel() == 1:
                    kappa = kappa.item()
                else:
                    kappa = kappa.mean().item()
            
            # Compute pi using utils function: pi = kappa * einsum('nkt,kdt->ndt', softmax(lambda), sigmoid(phi))
            pi_batch = calculate_pi_pred(lambda_, phi, kappa)
            
            # Clamp to avoid numerical issues
            epsilon = 1e-8
            pi_batch = torch.clamp(pi_batch, epsilon, 1 - epsilon)
            
            print(f"    Lambda shape: {lambda_.shape}, Phi shape: {phi.shape}, Kappa: {kappa}")
            print(f"    Computed pi shape: {pi_batch.shape}")
            pi_batches.append(pi_batch)
        
        # Concatenate all batches
        print("\nConcatenating computed pi batches...")
        pi_predictions = torch.cat(pi_batches, dim=0)
        print(f"✓ Computed and assembled full pi predictions: {pi_predictions.shape}")
        use_full_dataset = True

# Load disease names
disease_names_path = Path("/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/disease_names.csv")
if disease_names_path.exists():
    disease_names_df = pd.read_csv(disease_names_path)
    # Disease names are in column 1 (the "x" column), not column 0
    # Column 0 is the row number/ID
    # pandas.read_csv uses first row as column names, so iloc[:, 1] gives us the disease names
    disease_names = disease_names_df.iloc[:, 1].tolist()
    # Remove header value "x" if it's the first element
    if len(disease_names) > 0 and str(disease_names[0]).lower() == 'x':
        disease_names = disease_names[1:]
    # Convert all disease names to strings (they might be integers or have NaN)
    disease_names = [str(name) if pd.notna(name) else f"Disease_{i}" for i, name in enumerate(disease_names)]
    print(f"✓ Loaded {len(disease_names)} disease names")
    print(f"  First few: {disease_names[:5]}")
else:
    disease_names = [f"Disease_{i}" for i in range(pi_predictions.shape[1])]
    print("⚠️  Using placeholder disease names")

# Load Y and E to find patients with ASCVD → Lung Cancer progression
Y_path = Path("/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/Y_tensor.pt")
E_path = Path("/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_matrix.pt")

if Y_path.exists() and E_path.exists():
    Y_full = torch.load(str(Y_path), weights_only=False)
    E_full = torch.load(str(E_path), weights_only=False)
    
    if use_full_dataset:
        # Use full dataset
        Y_batch = Y_full
        E_batch = E_full
        print(f"✓ Using full dataset: {len(Y_batch)} patients")
    else:
        # Subset to batch 0-10000 (matching pi predictions)
        Y_batch = Y_full[0:10000]
        E_batch = E_full[0:10000]
        print(f"✓ Using subset: {len(Y_batch)} patients")
    
    # Find disease indices - ASCVD and Lung Cancer
    heart_disease_indices = []  # ASCVD (MI, CAD, ischemic heart disease)
    lung_cancer_idx = None  # lung cancer specifically
    
    # Heart disease terms (ASCVD)
    heart_disease_terms = [
        'myocardial infarction', 'coronary', 'ischemic heart', 'angina', 
        'coronary atherosclerosis', 'acute ischemic', 'chronic ischemic'
    ]
    
    # Lung cancer terms
    lung_cancer_terms = [
        'lung', 'bronchus', 'bronchial'
    ]
    
    for i, name in enumerate(disease_names):
        name_str = str(name).lower()
        
        # Check for heart disease (ASCVD)
        for term in heart_disease_terms:
            if term in name_str and i not in heart_disease_indices:
                heart_disease_indices.append(i)
                break
        
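        # Check for lung cancer. 'lung' alone also matches non-malignant lung diseases,
        # so require a cancer term as well (assumes cancer phenotypes are named with one
        # of these words) and keep the first match instead of letting later names overwrite it.
        if lung_cancer_idx is None and any(term in name_str for term in lung_cancer_terms):
            is_cancer = any(c in name_str for c in ('cancer', 'malignan', 'neoplasm'))
            # Exclude skin cancers
            if is_cancer and 'skin' not in name_str and 'melanoma' not in name_str:
                lung_cancer_idx = i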

    print(f"\nDisease indices:")
    if len(heart_disease_indices) > 0:
        print(f"  ASCVD found: {len(heart_disease_indices)}")
        print(f"  Sample ASCVD diseases:")
        for idx in heart_disease_indices[:5]:
            print(f"    - index {idx}: {disease_names[idx]}")
    
    if lung_cancer_idx is not None:
        print(f"  Lung cancer: index {lung_cancer_idx} ({disease_names[lung_cancer_idx]})")
    else:
        print(f"  ⚠️  Lung cancer not found")
    
    if len(heart_disease_indices) > 0 and lung_cancer_idx is not None:
        # Search for patients with ASCVD → Lung Cancer progression
        print(f"\nSearching for patients with ASCVD → Lung Cancer progression...")
        good_examples = []
        
        for patient_idx_cand in range(len(E_batch)):
            event_times = E_batch[patient_idx_cand]
            
            # Check if patient has ANY ASCVD
            heart_disease_ages = []
            for hd_idx in heart_disease_indices:
                if event_times[hd_idx] < 51:
                    heart_age = event_times[hd_idx].item()
                    heart_disease_ages.append((hd_idx, heart_age))
            
            if len(heart_disease_ages) == 0:
                continue
            
            # Get earliest ASCVD
            earliest_heart_idx, earliest_heart_age = min(heart_disease_ages, key=lambda x: x[1])
            
            # Check if patient has lung cancer
            if lung_cancer_idx is not None and event_times[lung_cancer_idx] < 51:
                lung_cancer_age = event_times[lung_cancer_idx].item()
                
                # Check if ASCVD occurs first
                if earliest_heart_age <= lung_cancer_age:
                    good_examples.append((patient_idx_cand, earliest_heart_idx, earliest_heart_age, lung_cancer_age))
        
        print(f"✓ Found {len(good_examples)} patients with ASCVD → Lung Cancer progression")
        
        # Find patient with elevated risk for BOTH diseases (using at-risk population baseline)
        print(f"\nSearching for patient with elevated risk for BOTH ASCVD and Lung Cancer...")
        print(f"  Using at-risk population baseline (corrected E filtering)")
        
        best_patient = None
        best_combined_score = -1
        
        # Evaluate patients with ASCVD → Lung Cancer progression
        for patient_idx_cand, heart_idx_cand, heart_age_cand, lung_cancer_age_cand in good_examples[:500]:  # Check first 500
            t_heart_cand = int(heart_age_cand)
            
            # Get patient predictions at heart disease diagnosis
            pi_at_heart_cand = pi_predictions[patient_idx_cand, :, t_heart_cand]
            
            # Calculate at-risk population baseline for each disease at this timepoint.
            # Only include patients still at risk (event time >= timepoint). For the
            # corrected-E baseline, E_batch should hold the event times from E_matrix_corrected.pt.
            at_risk_mask_heart = E_batch[:, heart_idx_cand] >= t_heart_cand
            at_risk_mask_lung = E_batch[:, lung_cancer_idx] >= t_heart_cand
            
            if at_risk_mask_heart.sum() > 0 and at_risk_mask_lung.sum() > 0:
                # At-risk population baselines
                heart_pop_at_risk = pi_predictions[at_risk_mask_heart, heart_idx_cand, t_heart_cand].mean().item()
                lung_pop_at_risk = pi_predictions[at_risk_mask_lung, lung_cancer_idx, t_heart_cand].mean().item()
                
                # Patient risks
                heart_pred = pi_at_heart_cand[heart_idx_cand].item()
                lung_pred = pi_at_heart_cand[lung_cancer_idx].item()
                
                # Risk ratios (using at-risk baseline)
                heart_rr = heart_pred / heart_pop_at_risk if heart_pop_at_risk > 0 else 0
                lung_rr = lung_pred / lung_pop_at_risk if lung_pop_at_risk > 0 else 0
                
                # Score: prefer patients with elevated risk for BOTH diseases
                # Combined score = product of risk ratios (both need to be > 1.0)
                if heart_rr > 1.0 and lung_rr > 1.0:
                    combined_score = heart_rr * lung_rr
                    if combined_score > best_combined_score:
                        best_combined_score = combined_score
                        best_patient = (patient_idx_cand, heart_idx_cand, heart_age_cand, 
                                       lung_cancer_age_cand, heart_rr, lung_rr)
        
        # Use best patient found, or fall back to patient 23941 for comparison
        target_patient_idx = 23941
        
        if best_patient is not None and best_combined_score > 1.5:
            # Use best patient with elevated risk for both
            patient_idx, heart_idx, heart_age, lung_cancer_age, heart_rr_found, lung_rr_found = best_patient
            print(f"\n✓ Selected Patient {patient_idx} with elevated risk for BOTH diseases:")
            print(f"  ASCVD RR: {heart_rr_found:.2f}x, Lung Cancer RR: {lung_rr_found:.2f}x")
        elif target_patient_idx < len(E_batch):
            # Fall back to patient 23941 for comparison
            event_times_target = E_batch[target_patient_idx]
            heart_disease_ages_target = []
            for hd_idx in heart_disease_indices:
                if event_times_target[hd_idx] < 51:
                    heart_age_target = event_times_target[hd_idx].item()
                    heart_disease_ages_target.append((hd_idx, heart_age_target))
            
            lung_cancer_age_target = None
            if lung_cancer_idx is not None and event_times_target[lung_cancer_idx] < 51:
                lung_cancer_age_target = event_times_target[lung_cancer_idx].item()
            
            if len(heart_disease_ages_target) > 0:
                earliest_heart_idx_target, earliest_heart_age_target = min(heart_disease_ages_target, key=lambda x: x[1])
                patient_idx = target_patient_idx
                heart_idx = earliest_heart_idx_target
                heart_age = earliest_heart_age_target
                lung_cancer_age = lung_cancer_age_target
                print(f"\n✓ Using Patient {patient_idx} (for comparison with old analysis)")
            else:
                if len(good_examples) > 0:
                    patient_idx, heart_idx, heart_age, lung_cancer_age = good_examples[0]
                    print(f"  Using Patient {patient_idx} from search")
                else:
                    patient_idx = None
        else:
            if len(good_examples) > 0:
                patient_idx, heart_idx, heart_age, lung_cancer_age = good_examples[0]
                print(f"  Using Patient {patient_idx} from search")
            else:
                patient_idx = None
         
        # Get event times for this patient (if we have a valid patient)
        if patient_idx is not None:
            event_times = E_batch[patient_idx]
            
            print(f"\nExample Patient: Patient {patient_idx}")
            print(f"  Disease progression:")
            print(f"    1. ASCVD ({disease_names[heart_idx]}) at age {heart_age + 30}")
            if lung_cancer_age is not None:
                print(f"    2. Lung Cancer ({disease_names[lung_cancer_idx]}) at age {lung_cancer_age + 30}")
            print(f"  Total diseases: {(E_batch[patient_idx] < 51).sum().item()}")
            
            # Calculate risk ratios
            t_heart = int(heart_age)
            pi_at_heart = pi_predictions[patient_idx, :, t_heart]
            population_baseline = pi_predictions[:, :, t_heart].mean(dim=0)
            
            lung_cancer_pred_final = pi_at_heart[lung_cancer_idx].item()
            lung_cancer_pop_final = population_baseline[lung_cancer_idx].item()
            lung_cancer_rr_final = lung_cancer_pred_final / lung_cancer_pop_final if lung_cancer_pop_final > 0 else 0
            
            print(f"\n✓ Patient {patient_idx} (Lung Cancer RR={lung_cancer_rr_final:.2f}x)")
            print(f"\nCalculating population baseline risks...")
            
            # Update for display
            earliest_cancer_idx = lung_cancer_idx
            cancer_age = lung_cancer_age
            mi_age = heart_age  # For compatibility with rest of code
            
            # Only continue if we have a valid patient
            if patient_idx is not None:
                # Find top predicted subsequent diseases (excluding heart diseases)
                other_diseases = [i for i in range(len(disease_names)) if i not in heart_disease_indices]
                pi_other = pi_at_heart[other_diseases]
                top_indices = torch.argsort(pi_other, descending=True)[:10]
                top_diseases = [other_diseases[i] for i in top_indices]
                
                print(f"\nTop 10 Predicted Subsequent Diseases (at Heart Disease diagnosis):")
                print("  Disease                          Predicted  Population  Risk Ratio")
                print("  " + "-"*70)
                for d_idx in top_diseases:
                    pred_risk = pi_at_heart[d_idx].item()
                    pop_risk = population_baseline[d_idx].item()
                    risk_ratio = pred_risk / pop_risk if pop_risk > 0 else float('inf')
                    marker = " ⭐" if d_idx == lung_cancer_idx and risk_ratio > 1.2 else ""
                    print(f"  {disease_names[d_idx][:30]:30s} {pred_risk:.4f}    {pop_risk:.4f}     {risk_ratio:.2f}x{marker}")
                
                print(f"\n  Risk Ratio = Predicted Risk / Population Risk")
                print(f"  Values > 1.0 indicate elevated risk relative to population average")
                print(f"  ⭐ = Lung Cancer with elevated risk")
                
                # Check what actually happened
                print(f"\nActual Subsequent Diseases (after Heart Disease diagnosis):")
                subsequent_diseases = []
                for d_idx in range(len(disease_names)):
                    if d_idx not in heart_disease_indices and E_batch[patient_idx, d_idx] < 51:
                        subsequent_age = E_batch[patient_idx, d_idx].item() + 30
                        if subsequent_age > heart_age + 30:  # After heart disease diagnosis
                            pred_risk = pi_at_heart[d_idx].item()
                            pop_risk = population_baseline[d_idx].item()
                            risk_ratio = pred_risk / pop_risk if pop_risk > 0 else float('inf')
                            subsequent_diseases.append((d_idx, subsequent_age, pred_risk, pop_risk, risk_ratio))
                
                if len(subsequent_diseases) > 0:
                    subsequent_diseases.sort(key=lambda x: x[1])  # Sort by age
                    print("  Disease                          Age    Predicted  Population  Risk Ratio")
                    print("  " + "-"*75)
                    for d_idx, age, pred_risk, pop_risk, risk_ratio in subsequent_diseases[:15]:
                        marker = " ⭐" if d_idx == lung_cancer_idx and risk_ratio > 1.2 else ""
                        print(f"  {disease_names[d_idx][:30]:30s} {age:3.0f}   {pred_risk:.4f}    {pop_risk:.4f}     {risk_ratio:.2f}x{marker}")
                    
                    # Highlight Lung Cancer specifically
                    if lung_cancer_idx is not None:
                        print(f"\n  Key Subsequent Disease: Lung Cancer")
                        for d_idx, age, pred_risk, pop_risk, risk_ratio in subsequent_diseases:
                            if d_idx == lung_cancer_idx:
                                print(f"    {disease_names[d_idx][:40]:40s} Age {age:3.0f}  Pred: {pred_risk:.4f}  Pop: {pop_risk:.4f}  RR: {risk_ratio:.2f}x")
                    
                    # Summary statistics
                    elevated_risk = [s for s in subsequent_diseases if s[4] > 1.5]  # Risk ratio > 1.5
                    print(f"\n  Summary:")
                    print(f"  - Total subsequent diseases: {len(subsequent_diseases)}")
                    print(f"  - Diseases with elevated risk (RR > 1.5x): {len(elevated_risk)}")
                    if len(elevated_risk) > 0:
                        avg_rr = sum(s[4] for s in elevated_risk) / len(elevated_risk)
                        print(f"  - Average risk ratio for elevated diseases: {avg_rr:.2f}x")
                    
                    print(f"\n✓ Aladynoulli predicted {len([s for s in subsequent_diseases if s[2] > 0.01])} subsequent diseases (predicted risk > 0.01)")
                    print(f"✓ Patient actually developed {len(subsequent_diseases)} subsequent diseases")
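                    # Only report the ASCVD → Lung Cancer sequence if lung cancer is among the recorded subsequent diseases
                    if any(s[0] == lung_cancer_idx for s in subsequent_diseases):
                        print("✓ Patient developed LUNG CANCER after ASCVD, demonstrating that both 'competing' outcomes can occur in one patient")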
                else:
                    print("  (No subsequent diseases yet)")
                
                print("\n" + "="*80)
                print("KEY INSIGHT:")
                print("="*80)
                print("Traditional competing risk models treat ASCVD and Lung Cancer as MUTUALLY EXCLUSIVE first events: follow-up ends at whichever occurs first.")
                print("But this patient developed BOTH - ASCVD first, then Lung Cancer.")
                print("Aladynoulli can predict lung cancer risk EVEN AFTER ASCVD diagnosis.")
                print("This demonstrates that 'competing risks' are not truly exclusive - patients can develop multiple serious conditions.")
                print("Aladynoulli's multi-disease approach correctly models this clinical reality.")
        else:
            print("\n⚠️  No patients found with ASCVD → Lung Cancer progression")
    else:
        print("\n⚠️  Could not find required disease indices (ASCVD, Lung Cancer)")
else:
    print("\n⚠️  Could not load Y/E tensors for patient identification")
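
The risk ratios reported in this cell are the patient's predicted risk divided by the cohort-mean predicted risk at the same time index. A minimal NumPy sketch of that calculation follows, assuming a prediction array of shape (patients, diseases, time points) like `pi_predictions`. The function name `risk_ratios_at` and the toy array are illustrative and not part of the notebook.

```python
import numpy as np

def risk_ratios_at(pi, patient_idx, t):
    """Patient risk / cohort-mean risk for every disease at time index t.

    pi: array of shape (n_patients, n_diseases, n_timepoints).
    Diseases whose cohort-mean risk is zero get NaN rather than 0 or inf.
    """
    patient = pi[patient_idx, :, t]
    population = pi[:, :, t].mean(axis=0)
    rr = np.full_like(patient, np.nan, dtype=float)
    # Divide only where the denominator is positive; other entries stay NaN
    np.divide(patient, population, out=rr, where=population > 0)
    return rr

# Toy data (hypothetical): 3 patients, 2 diseases, 2 time points
pi_toy = np.array([
    [[0.01, 0.02], [0.00, 0.00]],
    [[0.02, 0.04], [0.00, 0.00]],
    [[0.03, 0.06], [0.00, 0.00]],
])
rr_toy = risk_ratios_at(pi_toy, patient_idx=2, t=1)
print(rr_toy)  # first disease is 0.06 / 0.04; second disease is NaN
```

Returning NaN for a zero baseline is a design choice made for this sketch; the cell above uses `0` in one place and `float('inf')` in another for the same case.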



In [None]:
# ============================================================================
# VISUALIZATION: Risk Trajectories for Heart Disease and Lung Cancer
# Demonstrating Multi-Disease Prediction After Initial Diagnosis
# ============================================================================

import matplotlib.pyplot as plt
import numpy as np

if 'patient_idx' in locals() and patient_idx is not None and 'heart_idx' in locals():

    print("="*80)
    print("CREATING RISK TRAJECTORY PLOTS")
    print("="*80)
    
    # Create figure with 2 subplots
    fig, axes = plt.subplots(2, 1, figsize=(14, 10))
    
    ages = np.arange(30, 82)  # Ages 30-81
    
    # Get risk trajectories for both diseases
    heart_risk_patient = pi_predictions[patient_idx, heart_idx, :].numpy()
    heart_risk_pop = pi_predictions[:, heart_idx, :].mean(dim=0).numpy()
    
    lung_cancer_risk_patient = pi_predictions[patient_idx, lung_cancer_idx, :].numpy()
    lung_cancer_risk_pop = pi_predictions[:, lung_cancer_idx, :].mean(dim=0).numpy()
    
    # Calculate risk ratios
    heart_rr = heart_risk_patient / (heart_risk_pop + 1e-10)
    lung_cancer_rr = lung_cancer_risk_patient / (lung_cancer_risk_pop + 1e-10)
    
    # Get diagnosis ages
    heart_dx_age = heart_age + 30
    cancer_dx_age = lung_cancer_age + 30 if lung_cancer_age is not None else None
    
    # Calculate risk ratio at heart disease diagnosis
    t_heart = int(heart_age)
    rr_at_heart_dx = lung_cancer_rr[t_heart]
    
    # ===== PLOT 1: Absolute Predicted Risk =====
    ax1 = axes[0]
    
    # Plot patient risks
    ax1.plot(ages, heart_risk_patient, 'r-', linewidth=2.5, label=f'Patient Risk: {disease_names[heart_idx]}')
    ax1.plot(ages, lung_cancer_risk_patient, 'b-', linewidth=2.5, label=f'Patient Risk: {disease_names[lung_cancer_idx]}')
    
    # Plot population baselines
    ax1.plot(ages, heart_risk_pop, 'r--', linewidth=2, alpha=0.7, label=f'Population Baseline: {disease_names[heart_idx]}')
    ax1.plot(ages, lung_cancer_risk_pop, 'b--', linewidth=2, alpha=0.7, label=f'Population Baseline: {disease_names[lung_cancer_idx]}')
    
    # Add vertical line at heart disease diagnosis
    ax1.axvline(x=heart_dx_age, color='purple', linestyle=':', linewidth=2.5, 
                label=f'Heart Disease Diagnosis (Age {int(heart_dx_age)})')
    
    # Add shaded regions
    ax1.axvspan(30, heart_dx_age, alpha=0.1, color='gray', label='Before Heart Disease')
    if cancer_dx_age is not None:
        ax1.axvspan(heart_dx_age, cancer_dx_age, alpha=0.1, color='lightblue', label='After Heart Disease, Before Cancer')
        ax1.axvspan(cancer_dx_age, 80, alpha=0.1, color='lavender', label='After Cancer Diagnosis')
    else:
        ax1.axvspan(heart_dx_age, 80, alpha=0.1, color='lightblue', label='After Heart Disease')
    
    # Add annotation for risk ratio at heart disease diagnosis
    annotation_text = f'Lung Cancer RR: {rr_at_heart_dx:.2f}x at Heart Disease Dx'
    ax1.annotate(annotation_text, 
                xy=(heart_dx_age, lung_cancer_risk_patient[t_heart]),
                xytext=(heart_dx_age + 5, lung_cancer_risk_patient[t_heart] + 0.0005),
                fontsize=11, fontweight='bold', color='blue',
                bbox=dict(boxstyle='round,pad=0.5', facecolor='lightblue', alpha=0.7),
                arrowprops=dict(arrowstyle='->', color='blue', lw=2))
    
    ax1.set_xlabel('Age (years)', fontsize=13, fontweight='bold')
    ax1.set_ylabel('Predicted Disease Risk', fontsize=13, fontweight='bold')
    ax1.set_title(f'Patient {patient_idx}: Risk Trajectories for Heart Disease and Lung Cancer\nDemonstrating Multi-Disease Prediction After Initial Diagnosis',
                  fontsize=14, fontweight='bold', pad=15)
    ax1.legend(loc='upper left', fontsize=10, framealpha=0.9)
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim(bottom=0)
    ax1.set_xlim(30, 80)
    
    # ===== PLOT 2: Risk Ratio Trajectories =====
    ax2 = axes[1]
    
    # Plot risk ratios
    ax2.plot(ages, heart_rr, 'r-', linewidth=2.5, label=f'{disease_names[heart_idx]} Risk Ratio')
    ax2.plot(ages, lung_cancer_rr, 'b-', linewidth=2.5, label=f'{disease_names[lung_cancer_idx]} Risk Ratio')
    
    # Add horizontal line at RR=1.0
    ax2.axhline(y=1.0, color='red', linestyle='--', linewidth=1.5, alpha=0.7, label='Population Average (RR=1.0)')
    
    # Add shaded region for elevated risk
    ax2.axhspan(1.0, 5.0, alpha=0.1, color='green', label='Elevated Risk (RR > 1.0)')  # shade up to the y-axis limit set below
    
    # Add vertical line at heart disease diagnosis
    ax2.axvline(x=heart_dx_age, color='purple', linestyle=':', linewidth=2.5, 
                label=f'Heart Disease Diagnosis (Age {int(heart_dx_age)})')
    
    # Add annotation for the lung cancer risk ratio at heart disease diagnosis
    ax2.annotate(f'RR = {rr_at_heart_dx:.2f}x', 
                xy=(heart_dx_age, rr_at_heart_dx),
                xytext=(heart_dx_age + 3, rr_at_heart_dx + 0.3),
                fontsize=12, fontweight='bold', color='blue',
                bbox=dict(boxstyle='round,pad=0.5', facecolor='yellow', alpha=0.8),
                arrowprops=dict(arrowstyle='->', color='blue', lw=2))
    
    ax2.set_xlabel('Age (years)', fontsize=13, fontweight='bold')
    ax2.set_ylabel('Risk Ratio (Patient / Population)', fontsize=13, fontweight='bold')
    ax2.set_title('Risk Ratio Trajectories: Patient Risk Relative to Population Average', 
                  fontsize=14, fontweight='bold', pad=15)
    ax2.legend(loc='upper left', fontsize=10, framealpha=0.9)
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim(0, 5.0)
    ax2.set_xlim(30, 80)
    
    plt.tight_layout()
    plt.show()
    
    print("\n✓ Plots created successfully!")
    print(f"\nKey Findings:")
    print(f"1. Heart Disease risk peaks around age {ages[np.argmax(heart_risk_patient)]:.0f}")
    print(f"2. Lung Cancer risk ratio is {rr_at_heart_dx:.2f}x at heart disease diagnosis (age {int(heart_dx_age)})")
    print(f"3. Patient remains at risk for multiple diseases simultaneously")
    print(f"4. Risk trajectories demonstrate that competing risks are not mutually exclusive")
    
else:
    print("⚠️  Cannot create plots - patient data not available")
    print("   Please run Cell 4 first to identify example patient")
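
The plot annotation reports the risk ratio at the diagnosis time index, which is not necessarily the maximum of the trajectory. If the peak is also of interest, it can be read off directly. A small sketch, assuming a 1-D risk trajectory indexed from age 30 as in the cell above; `peak_risk_ratio` and the toy values are hypothetical.

```python
import numpy as np

def peak_risk_ratio(patient_risk, population_risk, start_age=30, eps=1e-10):
    """Return (age_at_peak, peak_rr) for a patient/population risk trajectory."""
    patient_risk = np.asarray(patient_risk, dtype=float)
    population_risk = np.asarray(population_risk, dtype=float)
    rr = patient_risk / (population_risk + eps)  # same zero-division guard as the plotting cell
    t_peak = int(np.argmax(rr))                  # time index of the largest risk ratio
    return start_age + t_peak, float(rr[t_peak])

# Toy trajectories (made-up values): the ratio is largest at the second time point
age_peak, rr_peak = peak_risk_ratio(
    patient_risk=[0.001, 0.004, 0.003],
    population_risk=[0.001, 0.002, 0.002],
)
print(f"Peak RR of {rr_peak:.2f}x at age {age_peak}")
```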



In [None]:
# ============================================================================
# FIND ALL THREE DISEASE PAIRS: ASCVD→Lung, ASCVD→Colon, Breast→ASCVD
# ============================================================================
"""
Find best examples for all three disease pairs with high risk ratios (>2.5x):
1. ASCVD → Lung Cancer
2. ASCVD → Colon Cancer  
3. Breast Cancer → ASCVD
"""

# Define disease search terms
def find_disease_indices(disease_names, search_terms_list, exclude_terms=None):
    """Find disease indices matching search terms"""
    indices = []
    exclude_terms = exclude_terms or []
    for i, name in enumerate(disease_names):
        name_str = str(name).lower()
        # Check if any search term matches
        if any(term.lower() in name_str for term in search_terms_list):
            # Check if should be excluded
            if not any(exclude.lower() in name_str for exclude in exclude_terms):
                indices.append(i)
    return indices

# Find all disease indices
print("\n" + "="*80)
print("FINDING DISEASE INDICES")
print("="*80)

# ASCVD diseases
ascvd_terms = [
    'myocardial infarction', 'coronary atherosclerosis', 'coronary', 
    'ischemic heart', 'angina', 'acute ischemic', 'chronic ischemic'
]
ascvd_indices = find_disease_indices(disease_names, ascvd_terms)
print(f"\nASCVD diseases: {len(ascvd_indices)} found")
for idx in ascvd_indices[:3]:
    print(f"  - {disease_names[idx]}")


# Colon cancer
# NOTE: matching is by substring, so 'anus' would also match a name such as 'tetanus'.
# Review the matched names and pass exclude_terms if unrelated codes appear.
colon_cancer_terms = ['colon cancer', 'rectum', 'rectosigmoid', 'anus']
colon_cancer_indices = find_disease_indices(disease_names, colon_cancer_terms)
print(f"\nColon cancer: {len(colon_cancer_indices)} found")
if colon_cancer_indices:
    print(f"  - {disease_names[colon_cancer_indices[0]]}")

# Breast cancer
# NOTE: 'breast' matches every name containing the word, including non-cancer breast conditions.
breast_cancer_terms = ['breast']
breast_cancer_indices = find_disease_indices(disease_names, breast_cancer_terms)
print(f"\nBreast cancer: {len(breast_cancer_indices)} found")
if breast_cancer_indices:
    print(f"  - {disease_names[breast_cancer_indices[0]]}")

# Function to find best patient for a disease pair
def find_best_patient_for_pair(first_disease_indices, second_disease_indices, pair_name, 
                               min_rr=2.5, max_patients_to_check=10000):
    """Find patient with highest risk ratio for second disease after first disease"""
    print(f"\n{'='*80}")
    print(f"SEARCHING: {pair_name}")
    print(f"{'='*80}")
    
    good_examples = []
    
    # Find patients with first disease → second disease progression
    for patient_idx in range(min(len(E_batch), max_patients_to_check)):
        event_times = E_batch[patient_idx]
        
        # Find earliest first disease
        first_disease_ages = []
        for fd_idx in first_disease_indices:
            if event_times[fd_idx] < 51:
                first_age = event_times[fd_idx].item()
                first_disease_ages.append((fd_idx, first_age))
        
        if len(first_disease_ages) == 0:
            continue
        
        earliest_first_idx, earliest_first_age = min(first_disease_ages, key=lambda x: x[1])
        
        # Check if patient has second disease
        for sd_idx in second_disease_indices:
            if event_times[sd_idx] < 51:
                second_age = event_times[sd_idx].item()
                # Keep the patient only if the first disease occurs at or before the second (same-age diagnoses count)
                if earliest_first_age <= second_age:
                    good_examples.append((patient_idx, earliest_first_idx, earliest_first_age, sd_idx, second_age))
                    break
    
    print(f"✓ Found {len(good_examples)} patients with {pair_name} progression")
    
    if len(good_examples) == 0:
        return None
    
    # Find best patient (highest risk ratio)
    best_example = None
    best_score = -1
    
    for patient_idx, first_idx, first_age, second_idx, second_age in good_examples[:500]:  # Only the first 500 candidates are scored
        t_first = int(first_age)
        pi_at_first = pi_predictions[patient_idx, :, t_first]
        pop_baseline = pi_predictions[:, :, t_first].mean(dim=0)
        
        second_pred = pi_at_first[second_idx].item()
        second_pop = pop_baseline[second_idx].item()
        second_rr = second_pred / second_pop if second_pop > 0 else 0
        
        # Only consider if RR > min_rr
        if second_rr > min_rr and second_rr > best_score:
            best_score = second_rr
            best_example = {
                'patient_idx': patient_idx,
                'first_disease_idx': first_idx,
                'first_disease_age': first_age,
                'second_disease_idx': second_idx,
                'second_disease_age': second_age,
                'risk_ratio': second_rr,
                'pair_name': pair_name
            }
    
    if best_example:
        print(f"✓ Best patient: {best_example['patient_idx']} (RR={best_example['risk_ratio']:.2f}x)")
        print(f"  First disease: {disease_names[best_example['first_disease_idx']]} at age {best_example['first_disease_age'] + 30:.0f}")
        print(f"  Second disease: {disease_names[best_example['second_disease_idx']]} at age {best_example['second_disease_age'] + 30:.0f}")
    else:
        print(f"⚠️  No patient found with RR > {min_rr}x")
    
    return best_example

# Find all three pairs
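examples = {}

# 1. ASCVD → Lung Cancer (reuses lung_cancer_idx from the earlier patient-search cell)
if len(ascvd_indices) > 0 and 'lung_cancer_idx' in locals() and lung_cancer_idx is not None:
    examples['ascvd_lung'] = find_best_patient_for_pair(
        ascvd_indices, [lung_cancer_idx], "ASCVD → Lung Cancer", min_rr=2.5
    )
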
# 2. ASCVD → Colon Cancer
if len(ascvd_indices) > 0 and len(colon_cancer_indices) > 0:
    examples['ascvd_colon'] = find_best_patient_for_pair(
        ascvd_indices, colon_cancer_indices, "ASCVD → Colon Cancer", min_rr=2.5
    )

# 3. Breast Cancer → ASCVD
if len(breast_cancer_indices) > 0 and len(ascvd_indices) > 0:
    examples['breast_ascvd'] = find_best_patient_for_pair(
        breast_cancer_indices, ascvd_indices, "Breast Cancer → ASCVD", min_rr=2.5
    )

# Summary
print("\n" + "="*80)
print("SUMMARY: ALL EXAMPLES FOUND")
print("="*80)
for key, ex in examples.items():
    if ex:
        print(f"\n{ex['pair_name']}:")
        print(f"  Patient {ex['patient_idx']}: RR={ex['risk_ratio']:.2f}x")
        print(f"  First: {disease_names[ex['first_disease_idx']]} (age {ex['first_disease_age'] + 30:.0f})")
        print(f"  Second: {disease_names[ex['second_disease_idx']]} (age {ex['second_disease_age'] + 30:.0f})")
    else:
        print(f"\n{key}: No example found with RR > 2.5x")

print(f"\n✓ Found {sum(1 for ex in examples.values() if ex)} out of {len(examples)} examples")
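
Substring matching, as used in the disease search above, can pick up unintended names: `'anus' in 'tetanus'` is `True` in Python. One hedged alternative is whole-word matching with the standard `re` module. The helper name `find_disease_indices_whole_word` and the sample names below are illustrative only and are not taken from the notebook's disease list.

```python
import re

def find_disease_indices_whole_word(disease_names, search_terms, exclude_terms=None):
    """Indices of names containing any search term as a whole word or phrase."""
    def compile_terms(terms):
        # \b anchors each term at word boundaries; re.escape keeps the term literal
        return [re.compile(r"\b" + re.escape(t) + r"\b", re.IGNORECASE) for t in terms]

    include = compile_terms(search_terms)
    exclude = compile_terms(exclude_terms or [])
    indices = []
    for i, name in enumerate(disease_names):
        text = str(name)
        if any(p.search(text) for p in include) and not any(p.search(text) for p in exclude):
            indices.append(i)
    return indices

# Hypothetical names chosen to show the difference from substring matching
sample_names = ["Cancer of anus", "Tetanus", "Benign neoplasm of breast", "Breast cancer"]
anus_hits = find_disease_indices_whole_word(sample_names, ["anus"])
breast_hits = find_disease_indices_whole_word(sample_names, ["breast"], exclude_terms=["benign"])
print(anus_hits, breast_hits)
```

Whole-word matching is stricter than the substring search: a term such as `'coronary'` would no longer match a differently inflected form, so the matched names should still be reviewed before use.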


In [None]:
# ============================================================================
# VISUALIZATION: Risk Trajectories for the ASCVD → Colon Cancer Example
# Demonstrating Multi-Disease Prediction After Initial Diagnosis
# ============================================================================

import matplotlib.pyplot as plt
import numpy as np

# Visualize the ASCVD → Colon Cancer example stored in the `examples` dictionary
# (built by the disease-pair search cell above)
if 'examples' in locals() and examples and 'ascvd_colon' in examples and examples['ascvd_colon']:
    ex = examples['ascvd_colon']
    patient_idx = ex['patient_idx']
    heart_idx = ex['first_disease_idx']
    colon_cancer_idx = ex['second_disease_idx']
    heart_age = ex['first_disease_age']
    colon_cancer_age = ex['second_disease_age']
    
    print("="*80)
    print("CREATING RISK TRAJECTORY PLOTS")
    print("="*80)
    print(f"Visualizing: {ex['pair_name']}")
    print(f"Patient {patient_idx}: ASCVD at age {heart_age + 30:.0f}, Colon Cancer at age {colon_cancer_age + 30:.0f}")
    
    # Create figure with 2 subplots
    fig, axes = plt.subplots(2, 1, figsize=(14, 10))
    
    ages = np.arange(30, 82)  # Ages 30-81
    
    # Get risk trajectories for both diseases
    heart_risk_patient = pi_predictions[patient_idx, heart_idx, :].numpy()
    heart_risk_pop = pi_predictions[:, heart_idx, :].mean(dim=0).numpy()
    
    colon_cancer_risk_patient = pi_predictions[patient_idx, colon_cancer_idx, :].numpy()
    colon_cancer_risk_pop = pi_predictions[:, colon_cancer_idx, :].mean(dim=0).numpy()
    
    # Calculate risk ratios
    heart_rr = heart_risk_patient / (heart_risk_pop + 1e-10)
    colon_cancer_rr = colon_cancer_risk_patient / (colon_cancer_risk_pop + 1e-10)
    
    # Get diagnosis ages
    heart_dx_age = heart_age + 30
    cancer_dx_age = colon_cancer_age + 30 if colon_cancer_age is not None else None
    
    # Calculate risk ratio at heart disease diagnosis
    t_heart = int(heart_age)
    rr_at_heart_dx = colon_cancer_rr[t_heart]
    
    # ===== PLOT 1: Absolute Predicted Risk =====
    ax1 = axes[0]
    
    # Plot patient risks
    ax1.plot(ages, heart_risk_patient, 'r-', linewidth=2.5, label=f'Patient Risk: {disease_names[heart_idx]}')
    ax1.plot(ages, colon_cancer_risk_patient, 'b-', linewidth=2.5, label=f'Patient Risk: {disease_names[colon_cancer_idx]}')
    
    # Plot population baselines
    ax1.plot(ages, heart_risk_pop, 'r--', linewidth=2, alpha=0.7, label=f'Population Baseline: {disease_names[heart_idx]}')
    ax1.plot(ages, colon_cancer_risk_pop, 'b--', linewidth=2, alpha=0.7, label=f'Population Baseline: {disease_names[colon_cancer_idx]}')
    
    # Add vertical line at heart disease diagnosis
    ax1.axvline(x=heart_dx_age, color='purple', linestyle=':', linewidth=2.5, 
                label=f'ASCVD Diagnosis (Age {int(heart_dx_age)})')
    
    # Add vertical line at colon cancer diagnosis (if it occurred)
    if cancer_dx_age is not None:
        ax1.axvline(x=cancer_dx_age, color='blue', linestyle='--', linewidth=2.5, 
                    label=f'Colon Cancer Diagnosis (Age {int(cancer_dx_age)})')
    
    # Add shaded regions
    ax1.axvspan(30, heart_dx_age, alpha=0.1, color='gray', label='Before ASCVD')
    if cancer_dx_age is not None:
        ax1.axvspan(heart_dx_age, cancer_dx_age, alpha=0.1, color='lightblue', label='After ASCVD, Before Colon Cancer')
        ax1.axvspan(cancer_dx_age, 80, alpha=0.1, color='lavender', label='After Colon Cancer Diagnosis')
    else:
        ax1.axvspan(heart_dx_age, 80, alpha=0.1, color='lightblue', label='After ASCVD')
    
    # Add annotation for risk ratio at ASCVD diagnosis
    annotation_text = f'Colon Cancer RR: {rr_at_heart_dx:.2f}x at ASCVD Dx'
    ax1.annotate(annotation_text, 
                xy=(heart_dx_age, colon_cancer_risk_patient[t_heart]),
                xytext=(heart_dx_age + 5, colon_cancer_risk_patient[t_heart] + 0.0005),
                fontsize=11, fontweight='bold', color='blue',
                bbox=dict(boxstyle='round,pad=0.5', facecolor='lightblue', alpha=0.7),
                arrowprops=dict(arrowstyle='->', color='blue', lw=2))
    
    ax1.set_xlabel('Age (years)', fontsize=13, fontweight='bold')
    ax1.set_ylabel('Predicted Disease Risk', fontsize=13, fontweight='bold')
    ax1.set_title(f'Patient {patient_idx}: Risk Trajectories for ASCVD and Colon Cancer\nDemonstrating Multi-Disease Prediction After Initial Diagnosis', 
                  fontsize=14, fontweight='bold', pad=15)
    ax1.legend(loc='upper left', fontsize=10, framealpha=0.9)
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim(bottom=0)
    ax1.set_xlim(30, 80)
    
    # ===== PLOT 2: Risk Ratio Trajectories =====
    ax2 = axes[1]
    
    # Plot risk ratios
    ax2.plot(ages, heart_rr, 'r-', linewidth=2.5, label=f'{disease_names[heart_idx]} Risk Ratio')
    ax2.plot(ages, colon_cancer_rr, 'b-', linewidth=2.5, label=f'{disease_names[colon_cancer_idx]} Risk Ratio')
    
    # Add horizontal line at RR=1.0
    ax2.axhline(y=1.0, color='red', linestyle='--', linewidth=1.5, alpha=0.7, label='Population Average (RR=1.0)')
    
    # Add shaded region for elevated risk
    ax2.axhspan(1.0, 4.0, alpha=0.1, color='green', label='Elevated Risk (RR > 1.0)')  # shade up to the y-axis limit set below
    
    # Add vertical line at heart disease diagnosis
    ax2.axvline(x=heart_dx_age, color='purple', linestyle=':', linewidth=2.5, 
                label=f'ASCVD Diagnosis (Age {int(heart_dx_age)})')
    
    # Add vertical line at colon cancer diagnosis (if it occurred)
    if cancer_dx_age is not None:
        ax2.axvline(x=cancer_dx_age, color='blue', linestyle='--', linewidth=2.5, 
                    label=f'Colon Cancer Diagnosis (Age {int(cancer_dx_age)})')
    
    # Add annotation for risk ratio at ASCVD diagnosis
    ax2.annotate(f'RR = {rr_at_heart_dx:.2f}x\nat ASCVD Dx', 
                xy=(heart_dx_age, rr_at_heart_dx),
                xytext=(heart_dx_age + 3, rr_at_heart_dx + 0.3),
                fontsize=12, fontweight='bold', color='blue',
                bbox=dict(boxstyle='round,pad=0.5', facecolor='yellow', alpha=0.8),
                arrowprops=dict(arrowstyle='->', color='blue', lw=2))
    
    # Add annotation for colon cancer diagnosis if it occurred
    if cancer_dx_age is not None:
        t_cancer = int(cancer_dx_age - 30)
        if t_cancer >= 0 and t_cancer < len(colon_cancer_rr):
            rr_at_cancer_dx = colon_cancer_rr[t_cancer]
            ax2.annotate(f'Colon Cancer\nDiagnosed\nAge {int(cancer_dx_age)}', 
                        xy=(cancer_dx_age, rr_at_cancer_dx),
                        xytext=(cancer_dx_age + 3, rr_at_cancer_dx + 0.5),
                        fontsize=11, fontweight='bold', color='blue',
                        bbox=dict(boxstyle='round,pad=0.5', facecolor='lightblue', alpha=0.8),
                        arrowprops=dict(arrowstyle='->', color='blue', lw=2))
    
    ax2.set_xlabel('Age (years)', fontsize=13, fontweight='bold')
    ax2.set_ylabel('Risk Ratio (Patient / Population)', fontsize=13, fontweight='bold')
    ax2.set_title('Risk Ratio Trajectories: Patient Risk Relative to Population Average', 
                  fontsize=14, fontweight='bold', pad=15)
    ax2.legend(loc='upper left', fontsize=10, framealpha=0.9)
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim(0, 4.0)
    ax2.set_xlim(30, 80)
    
    plt.tight_layout()
    plt.show()
    
    print("\n✓ Plots created successfully!")
    print(f"\nKey Findings:")
    print(f"1. Heart Disease risk peaks around age {ages[np.argmax(heart_risk_patient)]:.0f}")
    print(f"2. Colon Cancer risk ratio is {rr_at_heart_dx:.2f}x at ASCVD diagnosis (age {int(heart_dx_age)})")
    print(f"3. Patient remains at risk for multiple diseases simultaneously")
    print(f"4. Risk trajectories demonstrate that competing risks are not mutually exclusive")
    
else:
    print("⚠️  Cannot create plots - no ASCVD → Colon Cancer example available")
    print("   Please run the disease-pair search cell above first")



## 4. Explanation: Decreasing Hazards at Old Age

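The reviewer was concerned that the reported hazards decrease at old age. This is **not a model failure**; it reflects censoring and selection effects, summarized below and analyzed in detail in **R3_Q4_Decreasing_Hazards_Censoring_Bias.ipynb**: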


In [None]:
print("="*80)
print("EXPLANATION: DECREASING HAZARDS AT OLD AGE")
print("="*80)
print("\nThis is NOT a model failure but reflects:")
print("\n1. ADMINISTRATIVE CENSORING:")
print("   - All individuals censored at age 80 (standard in biobank analyses)")
print("   - Creates right-censoring that can appear as a declining hazard")
print("   - Limited follow-up beyond age 80 in UK Biobank")
print("\n2. COMPETING RISK OF DEATH:")
print("   - Individuals at age 75+ face high mortality risk")
print("   - Those who survive to 80 are SELECTED HEALTHY SURVIVORS")
print("   - Creates apparent risk reduction (survival bias)")
print("   - This is a REAL PHENOMENON, not a model artifact")
print("\n3. HEALTHY SURVIVOR EFFECT:")
print("   - Patients who survive to old age without disease are genuinely lower risk")
print("   - The model correctly captures this selection effect")
print("   - This is clinically meaningful: older patients without disease are healthier")
print("\nINTERPRETATION: The decreasing hazards at old age reflect both")
print("administrative censoring and the competing risk of death.")
print("This is EXPECTED and does not indicate model failure.")
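
The healthy survivor effect in point 3 can be illustrated with a toy calculation that is independent of the Aladynoulli model. Assume, hypothetically, a population made of a low-risk and a high-risk group, each with a constant individual hazard. The hazard observed among survivors then falls over time because high-risk individuals leave the at-risk pool sooner, even though no individual's hazard decreases. The group fractions and hazard values below are made-up illustration values.

```python
import numpy as np

def population_hazard(t, fractions, hazards):
    """Hazard among survivors at times t for a mixture of constant-hazard groups."""
    fractions = np.asarray(fractions, dtype=float)
    hazards = np.asarray(hazards, dtype=float)
    survival = np.exp(-np.outer(t, hazards))  # S_g(t), shape (len(t), n_groups)
    weights = fractions * survival            # share of each group still at risk
    return (weights * hazards).sum(axis=1) / weights.sum(axis=1)

years = np.arange(0, 51)  # years since start of follow-up
h_pop = population_hazard(years, fractions=[0.8, 0.2], hazards=[0.01, 0.10])
print(f"Population hazard at t=0: {h_pop[0]:.4f}, at t=50: {h_pop[-1]:.4f}")
```

In this sketch the population hazard starts at the fraction-weighted average of the two group hazards and drifts toward the low-risk group's hazard as follow-up lengthens.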
