# Model Validity & Learning Analysis

**Addresses**: R2 (Temporal Leakage) and R3 Q3 (Washout Windows)

This notebook collects analyses of **model validity** and **learning dynamics**. Together they show that the model learns appropriately and that washout periods prevent temporal leakage.

## Purpose

These analyses demonstrate:
1. **Model Learning**: How the model learns to distinguish high-risk from lower-risk patients
2. **Washout Validity**: Whether washout periods correctly prevent temporal leakage (R2)
3. **Signature Dynamics**: How patient-specific parameters (lambda) change as models are trained with more data
4. **Biological Validity**: Whether signature responses align with biological pathways

## Main Approach: Pooled Retrospective

All analyses use the `pooled_retrospective` approach by default, which:
- Uses phi trained externally and validated with leave-one-out (LOO) tests
- Represents clinically implementable behavior
- Uses pi from: `enrollment_predictions_fixedphi_RETROSPECTIVE_pooled/pi_enroll_fixedphi_sex_FULL.pt`


---

## SECTION 1: PREDICTION DROPS ANALYSIS

**Purpose**: Understand why predictions change between washout periods

This section analyzes why predictions drop between the 0-year and 1-year washout, focusing on precursor diseases such as hypercholesterolemia.
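To make the quantity under study concrete, here is a minimal sketch of how a per-patient prediction drop between the 0-year and 1-year washout could be computed and compared by precursor status. The function name `prediction_drops`, the array names, and the toy values are assumptions for illustration only; the actual computation lives in `analyze_prediction_drops.py` and may differ.

```python
import numpy as np

def prediction_drops(pred_0yr, pred_1yr):
    """Return (absolute, relative) drops from 0-year to 1-year washout predictions.

    Hypothetical helper: a positive value means the prediction fell after washout.
    """
    pred_0yr = np.asarray(pred_0yr, dtype=float)
    pred_1yr = np.asarray(pred_1yr, dtype=float)
    abs_drop = pred_0yr - pred_1yr
    # Relative drop is left at 0 where the 0-year prediction is 0 (avoids division by zero)
    rel_drop = np.divide(abs_drop, pred_0yr,
                         out=np.zeros_like(abs_drop), where=pred_0yr > 0)
    return abs_drop, rel_drop

# Toy data (not real results): four patients, two with a precursor diagnosis
pred_0yr = np.array([0.20, 0.10, 0.05, 0.00])
pred_1yr = np.array([0.10, 0.10, 0.06, 0.00])
has_precursor = np.array([True, True, False, False])

abs_drop, rel_drop = prediction_drops(pred_0yr, pred_1yr)
mean_drop_precursor = abs_drop[has_precursor].mean()
mean_drop_other = abs_drop[~has_precursor].mean()
print(f"Mean drop with precursor:    {mean_drop_precursor:.3f}")
print(f"Mean drop without precursor: {mean_drop_other:.3f}")
```

Comparing the two means is one simple way to see whether drops concentrate among patients with a precursor such as hypercholesterolemia.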


In [None]:
# ============================================================================
# ANALYZE PREDICTION DROPS
# ============================================================================
"""
Analyzes why predictions drop between 0-year and 1-year washout
Focuses on hypercholesterolemia and other precursor diseases
Results saved to: results/analysis/prediction_drops_*.csv
"""

import sys
from pathlib import Path

# Add the scripts directory to sys.path so modules in it can be imported
script_dir = Path('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks')
sys.path.insert(0, str(script_dir))

%run /Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/analyze_prediction_drops.py --disease ASCVD


Loading data...
ANALYZING PREDICTION DROPS FOR: ASCVD
'ASCVD' is a disease group. Finding individual diseases...
  Found: Myocardial infarction at index 112
  Found: Coronary atherosclerosis at index 114
  Found: Other acute and subacute forms of ischemic heart disease at index 116
  Found: Unstable angina (intermediate coronary syndrome) at index 111
  Found: Angina pectoris at index 113
  Found: Other chronic ischemic heart disease, unspecified at index 115
Found 6 disease(s) for 'ASCVD'

Analyzing 400000 patients...

Collecting predictions and outcomes...
Collected 400000 patients with both 0yr and 1yr predictions

NOTE: Prevalent case exclusion (matches evaluation function logic):
  - For single diseases: Patients with that disease before prediction time are excluded
  - For disease groups (like ASCVD): Prevalent cases are NOT excluded
    (patients can have multiple events in the group, e.g., CAD then MI)
  - This matches the evaluation function's approach for disease groups


In [None]:
# ============================================================================
# VISUALIZE PREDICTION DROPS
# ============================================================================
"""
Creates plots for prediction drops analysis
Plots saved to: results/analysis/plots/
"""

import sys
from pathlib import Path

# Add the scripts directory to sys.path so modules in it can be imported
script_dir = Path('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks')
sys.path.insert(0, str(script_dir))

%run /Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/visualize_prediction_drops.py --disease ASCVD


VISUALIZING PREDICTION DROPS ANALYSIS: ASCVD

Loading results...
✓ Loaded 3 result files

Creating plots...
✓ Saved plot to: /Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/results/analysis/plots/hyperchol_comparison_ASCVD.png
✓ Saved plot to: /Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/results/analysis/plots/precursor_comparison_ASCVD.png
✓ Saved plot to: /Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/results/analysis/plots/precursor_ratios_ASCVD.png

VISUALIZATION COMPLETE

Plots saved to: /Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/results/analysis/plots


In [None]:
# ============================================================================
# VISUALIZE MODEL LEARNING (KEY INSIGHT FIGURE)
# ============================================================================
"""
Creates a figure showing the key insight: Model learns to distinguish between
high-risk and lower-risk hypercholesterolemia patients.

Non-droppers (predictions stay high) have HIGHER event rates → Model correctly
identifies high-risk patients. This shows the model is learning and calibrating.
"""
import sys
from pathlib import Path

# Add the scripts directory to sys.path so modules in it can be imported
script_dir = Path('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks')
sys.path.insert(0, str(script_dir))

%run /Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/visualize_model_learning.py --disease ASCVD


CREATING MODEL LEARNING FIGURES
✓ Saved figure to: /Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/results/analysis/plots/model_learning_hyperchol_ASCVD.png
✓ Saved full comparison figure to: /Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/results/analysis/plots/model_learning_full_comparison_ASCVD.png
✓ Saved multiple precursors figure to: /Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/results/analysis/plots/model_learning_multiple_precursors_ASCVD.png

COMPLETE
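
The key comparison in the figure above (event rates among patients whose predictions stay high versus those whose predictions drop) can be sketched as follows. The definition of a "dropper" used here, a relative decrease of at least `drop_threshold`, is an assumption made for illustration; the source does not state the rule used in `visualize_model_learning.py`, and all data below are toy values.

```python
import numpy as np

def event_rate_by_drop_status(pred_0yr, pred_1yr, had_event, drop_threshold=0.5):
    """Compare observed event rates for droppers vs non-droppers.

    ASSUMPTION: a patient is a "dropper" if the 1-year-washout prediction is
    below (1 - drop_threshold) times the 0-year-washout prediction.
    """
    pred_0yr = np.asarray(pred_0yr, dtype=float)
    pred_1yr = np.asarray(pred_1yr, dtype=float)
    had_event = np.asarray(had_event, dtype=bool)

    is_dropper = pred_1yr < (1.0 - drop_threshold) * pred_0yr
    rates = {}
    for label, mask in (("dropper", is_dropper), ("non_dropper", ~is_dropper)):
        # NaN signals an empty group rather than a zero event rate
        rates[label] = float(had_event[mask].mean()) if mask.any() else float("nan")
    return rates

# Toy data (not real results): six patients with the same 0-year prediction
pred_0yr = [0.20, 0.20, 0.20, 0.20, 0.20, 0.20]
pred_1yr = [0.05, 0.05, 0.05, 0.18, 0.19, 0.21]
had_event = [0, 0, 1, 1, 1, 0]

rates = event_rate_by_drop_status(pred_0yr, pred_1yr, had_event)
print(rates)
```

If non-droppers show the higher event rate, as the cell's docstring describes, the model is keeping predictions high for the patients who go on to have events.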


---

## SECTION 2: MI WASHOUT ANALYSIS

**Purpose**: Validate washout periods using signature-based learning

This section analyzes myocardial infarction (MI) washout with signature-based learning to show how the model learns from different time periods.



In [5]:
# ============================================================================
# MI WASHOUT ANALYSIS: SIGNATURE-BASED LEARNING
# ============================================================================
"""
Analyzes MI (Myocardial Infarction) washout with signature-based learning.

For each patient, tracks:
- 3 MODELS: m0t9, m5t9, m9t9 (all predict at t9, trained to t0, t5, t9)
- 3 TIME PERIODS: 
  1) Baseline (before t0/enrollment)
  2) Interval t0-t5
  3) Interval t5-t9
- For each period: MI status and Signature 5 precursor diseases

Categorizes washout based on what developed in intervals (not baseline).
"""

import subprocess
import sys
from pathlib import Path

script_dir = Path('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks')

result = subprocess.run([
    sys.executable,
    str(script_dir / 'analyze_mi_washout_signature.py'),
    '--start_idx', '0',
    '--end_idx', '10000'
], capture_output=True, text=True)

print(result.stdout)
if result.stderr:
    print("STDERR:", result.stderr)
if result.returncode != 0:
    print(f"\n⚠️  WARNING: Script exited with return code {result.returncode}")


MI WASHOUT ANALYSIS WITH SIGNATURE-BASED LEARNING
Batch: 0-10000

Loading essentials...
✓ Found MI at index 112: Myocardial infarction

Loading cluster assignments...
✓ Loaded clusters: 348 diseases

✓ MI belongs to Signature 5
✓ Found 7 diseases in Signature 5
  Examples: ['Hypercholesterolemia', 'Unstable angina (intermediate coronary syndrome)', 'Myocardial infarction', 'Angina pectoris', 'Coronary atherosclerosis']

Loading data batch 0-10000...

Loading pi batches for offsets 0-9...
✓ Loaded 10 pi batches

Loading model checkpoints to extract lambda...

ANALYZING MI WASHOUT

Analyzing 10000 patients...
MI index: 112
Signature 5 has 7 diseases

✓ Saved results to: results/analysis/mi_washout_analysis_batch_0_10000.csv

SUMMARY STATISTICS
Total patients analyzed: 10000

Washout categories:
washout_category
neither         8694
accurate        1028
conservative     278
Name: count, dtype: int64

MI status at t9: 435 patients (4.3%)

Patients with Signature 5 precursors at t9: 1600 (1

In [None]:
import sys
from pathlib import Path

# Add the scripts directory to sys.path so modules in it can be imported
script_dir = Path('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks')
sys.path.insert(0, str(script_dir))

%run /Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/visualize_mi_washout_signature.py


VISUALIZING MI WASHOUT ANALYSIS
✓ Saved figure to: /Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks/results/analysis/plots/mi_washout_signature_analysis.png

VISUALIZATION COMPLETE


---

## SECTION 3: AGE OFFSET SIGNATURE ANALYSIS

**Purpose**: Understand how patient-specific parameters (lambda) change as models are trained with more data

This analysis shows how the model learns and adapts as more data becomes available, distinguishing between:
- Conservative washout (with outcome events)
- Accurate washout (with precursor only)
- Model refinement (without either)
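
The three-way distinction above can be written as a small rule. This is a sketch under stated assumptions, not the logic of the analysis scripts: the function name `categorize_washout` and its boolean inputs are hypothetical, and giving an outcome precedence over a precursor when both occur is an assumption.

```python
from collections import Counter

def categorize_washout(had_outcome_in_interval, had_precursor_in_interval):
    """Assign a washout category from what developed during the washout interval.

    ASSUMPTION: an outcome event takes precedence when both an outcome and a
    precursor occurred; baseline (pre-enrollment) diagnoses are not considered.
    """
    if had_outcome_in_interval:
        return "conservative"   # outcome event occurred during washout
    if had_precursor_in_interval:
        return "accurate"       # precursor only, no outcome yet
    return "neither"            # model refinement group

# Toy patients (not real data): (had_outcome, had_precursor)
patients = [(True, True), (False, True), (False, False), (False, False), (True, False)]
categories = [categorize_washout(outcome, precursor) for outcome, precursor in patients]
counts = Counter(categories)
print(counts)
```

The labels match the `washout_category` values reported in the MI washout output (`conservative`, `accurate`, `neither`).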


In [7]:
# ============================================================================
# ANALYZE AGE OFFSET SIGNATURE CHANGES
# ============================================================================
"""
Analyzes how predictions and signature loadings change across age offsets (t0-t9).

For patients with specific precursor diseases, tracks:
1. How their predictions change across offsets 0-9
2. Which signatures/clusters are most impacted
3. Which precursor diseases drive which signature changes

This shows how the model learns and adapts as more data becomes available.
"""

# Run analysis for key precursor diseases
import subprocess
import sys
from pathlib import Path

script_dir = Path('/Users/sarahurbut/aladynoulli2/pyScripts/new_oct_revision/new_notebooks')

result = subprocess.run([
    sys.executable,
    str(script_dir / 'analyze_age_offset_signatures.py'),
    '--approach', 'pooled_retrospective',
    '--target_disease', 'ASCVD',
    '--start_idx', '0',
    '--end_idx', '10000'
], capture_output=True, text=True)

print(result.stdout)
if result.stderr:
    print("STDERR:", result.stderr)
if result.returncode != 0:
    print(f"\n⚠️  WARNING: Script exited with return code {result.returncode}")


ANALYZING AGE OFFSET SIGNATURE CHANGES

Approach: pooled_retrospective
Batch: 0-10000
Target disease: ASCVD
Precursor diseases: ['Hypercholesterolemia', 'Essential hypertension', 'Type 2 diabetes', 'Atrial fibrillation and flutter', 'Obesity', 'Chronic Kidney Disease, Stage III', 'Rheumatoid arthritis', 'Sleep apnea', 'Peripheral vascular disease, unspecified']

Loading essentials...
Loading cluster assignments...
  ✓ Loaded clusters as numpy array from: /Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/initial_clusters_400k.pt
  Cluster shape: (348,)
✓ Loaded clusters: 348 diseases, 20 clusters

Loading data batch 0-10000...

Loading pi batches for offsets 0-9...
  Loading offset 0...
  Loading offset 1...
  Loading offset 2...
  Loading offset 3...
  Loading offset 4...
  Loading offset 5...
  Loading offset 6...
  Loading offset 7...
  Loading offset 8...
  Loading offset 9...
✓ Loaded 10 pi batches

Loading model checkpoints to extract lambda (patient-specifi

## Summary: Age Offset Signature Analysis

**Question:** When models are trained with different amounts of data (washout periods), how do patient-specific parameters (lambda) change, and does this reflect conservative vs. accurate washout?

**Findings:**

1. **Conservative washout (with outcome events):**
   - Patients who had ASCVD events during washout
   - Signature 5 (cardiovascular cluster) shows large positive lambda changes (+0.587 for hypercholesterolemia)
   - Model learns from patients who already had outcomes

2. **Accurate washout (with precursor only):**
   - Patients with precursors (e.g., hypercholesterolemia) but no ASCVD outcome during washout
   - Signature 5 shows moderate positive lambda changes (+0.305)
   - Model learns from pre-clinical signals (risk factors before outcomes)

3. **Model refinement (without either):**
   - Patients with neither precursor nor outcome
   - Small negative lambda changes (-0.053)
   - Model becomes more conservative/refined

**Interpretation:**
- The model distinguishes between:
  - Real conditions (outcomes) → large changes
  - Pre-clinical signals (precursors) → moderate changes
  - Neither → small/negative changes
- This supports washout accuracy: the model learns from legitimate risk factors, not just future outcomes
- Signature 5 correctly responds to cardiovascular precursors even when outcomes haven't occurred yet

**Conclusion:** This pattern supports model validity and washout accuracy. The model learns appropriately from pre-clinical signals, which is the intended behavior for accurate washout.
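
A minimal sketch of the comparison behind these findings is shown below: the mean change in one signature's lambda between the first and last offset, grouped by washout category. The array layout `(n_offsets, n_patients, n_signatures)`, the choice of last-minus-first as the "change", and all numbers are assumptions for illustration; the real lambda values are extracted from model checkpoints by `analyze_age_offset_signatures.py`.

```python
import numpy as np

def mean_lambda_change(lambda_by_offset, groups, signature_idx):
    """Mean change in one signature's lambda from the first to the last offset, per group.

    ASSUMPTION: lambda_by_offset has shape (n_offsets, n_patients, n_signatures),
    i.e. one summary value per patient and signature for each offset model.
    """
    lam = np.asarray(lambda_by_offset, dtype=float)
    groups = np.asarray(groups)
    # Change = lambda under the last offset model minus lambda under the first
    change = lam[-1, :, signature_idx] - lam[0, :, signature_idx]
    return {str(g): float(change[groups == g].mean()) for g in np.unique(groups)}

# Toy data (not real results): 2 offsets, 4 patients, 2 signatures
lam = np.zeros((2, 4, 2))
lam[1, :, 1] = [0.5, 0.25, -0.1, 0.0]
groups = ["conservative", "accurate", "neither", "neither"]

changes = mean_lambda_change(lam, groups, signature_idx=1)
for group, delta in changes.items():
    print(f"{group:>12}: {delta:+.3f}")
```

Grouping by washout category keeps the three patterns (large, moderate, small or negative change) directly comparable on the same signature.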
