# 04. Fusion vs. Blind Model Comparison
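Run this notebook **after** `src/train_fusion.py` has finished, so that the fusion checkpoint (`src/models/resnet1d_fusion_best.pth`) exists. The blind checkpoint (`src/models/resnet1d_best.pth`) must also be present.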

### Goals:
1. Load both trained models (blind and fusion).
2. Compare their macro-averaged AUROC scores on the validation fold.
3. Investigate whether fusion fixed the "Digitalis" confusion.

In [1]:
import os
import sys
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve
from torch.utils.data import DataLoader

# Add the parent of the current working directory to the import path so that the
# `src.*` imports below resolve (paths here are relative to the kernel's cwd).
sys.path.append(os.path.abspath('..'))

from src.models.resnet1d import resnet1d50
from src.models.resnet1d_fusion import resnet1d_fusion
from src.data.dataset_fusion import PTBXLDatasetFusion

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Root directory of the PTB-XL data, relative to the kernel's working directory.
DATA_DIR = '../data/ptb-xl'

## 1. Load Data (Validation Set Only)

In [None]:
# Both models are fed from PTBXLDatasetFusion. The blind model expects ONE input,
# so the context tensor produced by this dataset must not be passed to it.

# Replicate the split logic (fold 1).
# NOTE(review): n_splits / shuffle / random_state are presumably the values used by
# src/train_fusion.py -- confirm they match, otherwise this "validation" fold would
# overlap with the data the models were trained on.
from src.train_fusion import load_metadata
from sklearn.model_selection import StratifiedGroupKFold

df = load_metadata(DATA_DIR)

# Stratify on the label summary column while keeping every patient's rows in one fold.
sgkf = StratifiedGroupKFold(n_splits=10, shuffle=True, random_state=42)
splits = list(sgkf.split(df.index, df['strat_target'], df['patient_id']))
_, val_idx = splits[0]  # Fold 1
val_df = df.iloc[val_idx]

# len(val_df) counts rows (records), not patients: the split is grouped by
# patient_id, so one patient may contribute several rows. Report both numbers.
n_val_patients = val_df['patient_id'].nunique()
print(f"Validation Set: {len(val_df)} records from {n_val_patients} patients")

# shuffle=False keeps predictions aligned with the row order of val_df.
val_ds = PTBXLDatasetFusion(val_df, DATA_DIR, sampling_rate=500, use_ram_cache=True)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)

Validation Set: 2180 patients
Loading 2180 records into RAM (Fusion Mode)...


 18%|██████████████████▏                                                                                    | 385/2180 [00:17<01:36, 18.62it/s]

## 2. Load Models

In [None]:
def _load_weights(model, checkpoint_path, loaded_msg, missing_msg):
    """Load a saved state dict into ``model`` in place.

    Only a missing checkpoint file is tolerated: ``missing_msg`` is printed together
    with a warning and the model keeps the weights it was constructed with. Any other
    failure (corrupt file, state-dict key/shape mismatch, interrupt) propagates, so a
    real problem is no longer reported as "not found".

    Args:
        model: module whose ``load_state_dict`` receives the checkpoint.
        checkpoint_path: path of the ``.pth`` file to read.
        loaded_msg: text printed after the weights were loaded.
        missing_msg: text printed when ``checkpoint_path`` does not exist.
    """
    try:
        # torch.load unpickles the file -- only point this at checkpoints you trust.
        state_dict = torch.load(checkpoint_path, map_location=DEVICE)
    except FileNotFoundError:
        print(missing_msg)
        print(f"WARNING: no weights loaded from {checkpoint_path}; "
              "scores computed with this model are not meaningful.")
        return
    model.load_state_dict(state_dict)
    print(loaded_msg)


# Blind Model
blind_model = resnet1d50(num_classes=5).to(DEVICE)
_load_weights(
    blind_model,
    '../src/models/resnet1d_best.pth',
    "Loaded Blind Model.",
    "Blind model not found.",
)

# Fusion Model
fusion_model = resnet1d_fusion(num_classes=5).to(DEVICE)
_load_weights(
    fusion_model,
    '../src/models/resnet1d_fusion_best.pth',
    "Loaded Fusion Model.",
    "Fusion model not found! (Did you run train_fusion.py?)",
)

## 3. Compare Performance

In [None]:
# Switch both networks to evaluation mode before scoring.
for net in (blind_model, fusion_model):
    net.eval()

# Per-batch outputs, collected as NumPy arrays and stitched together after the loop.
preds_blind, preds_fusion, targets = [], [], []

with torch.no_grad():
    for signal, context, labels in val_loader:
        signal = signal.to(DEVICE)
        context = context.to(DEVICE)

        # The blind model sees the signal only; the fusion model also gets the context.
        blind_logits = blind_model(signal)
        fusion_logits = fusion_model(signal, context)

        # Sigmoid turns each logit into an independent per-class probability.
        preds_blind.append(torch.sigmoid(blind_logits).cpu().numpy())
        preds_fusion.append(torch.sigmoid(fusion_logits).cpu().numpy())
        targets.append(labels.numpy())

# Stack the per-batch arrays into one array per model / for the labels.
p_b, p_f, y_true = (
    np.concatenate(batches) for batches in (preds_blind, preds_fusion, targets)
)

# Macro average: AUROC is computed per class and the class scores are averaged unweighted.
auc_b = roc_auc_score(y_true, p_b, average='macro')
auc_f = roc_auc_score(y_true, p_f, average='macro')

print(f"Blind AUROC: {auc_b:.4f}")
print(f"Fusion AUROC: {auc_f:.4f}")
# The gain is the absolute AUROC difference scaled by 100 (i.e. AUROC points).
auc_gain = auc_f - auc_b
print(f"Improvement: {auc_gain*100:.2f}%")