# Daily Rolling Risk Curves - Ticket 008

This notebook generates daily migraine risk curves from 24-hour windows using the trained ALINE model.

In [None]:
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd().parent))

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import yaml

from models.aline import SimpleALINE
from viz.rolling import generate_rolling_curves, plot_rolling_risk_curve

%matplotlib inline

In [None]:
# Load the model architecture configuration. These keys are passed straight to
# the SimpleALINE constructor in the next cell, so fail fast here with a readable
# message if the file is empty (yaml.safe_load returns None), is not a mapping,
# or is missing a key.
MODEL_CONFIG_PATH = Path('../configs/model.yaml')
REQUIRED_MODEL_KEYS = ('in_dim', 'z_dim', 'd_model', 'nhead', 'nlayers')

with MODEL_CONFIG_PATH.open(encoding='utf-8') as f:
    model_config = yaml.safe_load(f)

assert isinstance(model_config, dict), (
    f"{MODEL_CONFIG_PATH} must contain a YAML mapping, got {type(model_config).__name__}"
)
missing_keys = [key for key in REQUIRED_MODEL_KEYS if key not in model_config]
assert not missing_keys, f"{MODEL_CONFIG_PATH} is missing keys: {missing_keys}"

print("Model configuration loaded")
print(f"  in_dim: {model_config['in_dim']}")
print(f"  z_dim: {model_config['z_dim']}")

In [None]:
# Rebuild the ALINE architecture from the YAML config, then restore the trained
# weights from the best checkpoint and switch the model to inference mode.
checkpoint_path = '../runs/checkpoints/best.pt'

# NOTE(review): torch.load unpickles the file unless weights_only is enabled
# (the default only from PyTorch 2.6 onward) — load trusted checkpoints only.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

# Constructor kwargs are taken from the config in exactly this key order.
model = SimpleALINE(**{
    key: model_config[key]
    for key in ('in_dim', 'z_dim', 'd_model', 'nhead', 'nlayers')
})
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference mode (affects dropout / batch-norm layers, if present)

print("✓ Model loaded successfully")
# Validation metrics are optional checkpoint entries; show 'N/A' when absent.
print(f"  Validation AUC: {checkpoint.get('val_auc', 'N/A')}")
print(f"  Validation loss: {checkpoint.get('val_loss', 'N/A')}")

In [None]:
# Load the validation split. generate_rolling_curves (below) is given the same
# CSV path and reads the file itself; `df` here supplies the feature list.
VAL_DATA_PATH = '../data/synthetic_migraine_val.csv'
df = pd.read_csv(VAL_DATA_PATH)

# Identifier, latent (Z_*), and target columns are not model inputs; every other
# column is treated as a feature. A list comprehension (rather than
# df.columns.difference, which sorts) keeps the CSV's column order — presumably
# the order the model was trained on; TODO confirm against the training pipeline.
NON_FEATURE_COLS = {
    'user_id', 'day', 'Z_stress', 'Z_sleepDebt', 'Z_hormonal', 'Z_envLoad',
    'migraine_prob', 'migraine',
}
feature_cols = [col for col in df.columns if col not in NON_FEATURE_COLS]

assert 'user_id' in df.columns, f"{VAL_DATA_PATH} has no 'user_id' column"
assert feature_cols, f"no feature columns left in {VAL_DATA_PATH} after exclusions"

# nunique() counts distinct users directly (and ignores missing ids).
print(f"Loaded {len(df)} records from {df['user_id'].nunique()} users")
print(f"Features: {len(feature_cols)} columns")

In [None]:
# Seed the global RNGs so the curves (and any sampled uncertainty bands) are
# reproducible under Restart & Run All.
# NOTE(review): assumes generate_rolling_curves draws randomness from torch
# and/or numpy's global RNG; seeding is harmless if it is fully deterministic.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Migraine prediction weights/bias copied from the simulator. Presumably one
# weight per latent factor, in the order Z_stress, Z_sleepDebt, Z_hormonal,
# Z_envLoad — TODO confirm against the simulator before changing.
migraine_weights = torch.tensor([0.5, 0.4, 0.45, 0.35])
migraine_bias = -1.8

N_USERS = 3  # number of users to generate (and save) curves for
ROLLING_CURVES_DIR = '../artifacts/rolling_curves'

# Generate rolling curves; figures are written to ROLLING_CURVES_DIR.
results = generate_rolling_curves(
    model=model,
    data_path='../data/synthetic_migraine_val.csv',
    feature_cols=feature_cols,
    migraine_weights=migraine_weights,
    migraine_bias=migraine_bias,
    device='cpu',
    output_dir=ROLLING_CURVES_DIR,
    n_users=N_USERS,
)

print(f"\n✓ Generated curves for {len(results)} users")

In [None]:
# Plot the risk curve for the first user in `results` (its iteration order).
# Fail with a readable message instead of an opaque IndexError if no curves
# were produced.
if not results:
    raise RuntimeError("generate_rolling_curves returned no users; nothing to plot")

user_key = next(iter(results))
user_data = results[user_key]

plot_rolling_risk_curve(
    user_data['days'],
    user_data['mean_probs'],
    user_data['lower_bounds'],
    user_data['upper_bounds'],
    true_migraines=user_data['true_migraines'],
    title=f"{user_key}: Daily Migraine Risk Prediction",
)
plt.show()

In [None]:
# Per-user point metrics for the daily risk curves:
#   - Brier Score: mean squared difference between predicted probability and outcome.
#   - Mean Predicted vs. Mean Actual, and their absolute difference ("Calibration
#     Error"): this is calibration-in-the-large only — it does not check
#     reliability across probability bins or coverage of the lower/upper bounds.
# Loop variables are named distinctly so they don't overwrite `user_key` /
# `user_data` from the plotting cell.
user_metrics = {}
for curve_user, curve in results.items():
    # Normalise to float arrays so lists, numpy arrays, or CPU tensors (and
    # boolean/int outcome labels) all behave the same in the arithmetic below.
    probs = np.asarray(curve['mean_probs'], dtype=float)
    outcomes = np.asarray(curve['true_migraines'], dtype=float)
    # Guard against silent broadcasting, e.g. (n,) predictions vs (n, 1) labels.
    assert probs.shape == outcomes.shape, (
        f"{curve_user}: predictions {probs.shape} vs outcomes {outcomes.shape}"
    )

    mean_pred = probs.mean()
    mean_actual = outcomes.mean()
    user_metrics[curve_user] = {
        'Brier Score': np.mean((probs - outcomes) ** 2),
        'Mean Predicted': mean_pred,
        'Mean Actual': mean_actual,
        'Calibration Error': abs(mean_pred - mean_actual),
    }

# One row per user; rounded for display only (metrics_df keeps full precision).
metrics_df = pd.DataFrame.from_dict(user_metrics, orient='index')
metrics_df.round(4)

## Summary

The rolling risk curves show:
- Daily migraine probability predictions with 90% confidence intervals
- Ground truth migraine occurrences marked with red triangles
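- Model appears to capture temporal patterns; calibration is assessed here only via per-user Brier score and mean predicted vs. observed migraine rate (coverage of the uncertainty intervals is not yet evaluated)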
- Saved figures in `../artifacts/rolling_curves/` for UI team