In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, f1_score, recall_score, confusion_matrix


In [None]:

# Conformal Prediction Helper Functions
def compute_nonconformity(model, X_val, y_val, alpha=0.8):
    """Score calibration samples by distance between predicted probability and labels.

    For every sample the score is
    ``alpha * |p - y_val[:, 0]| + (1 - alpha) * |p - y_val[:, 1]|``,
    where ``p`` is the softmax probability of class index 1.

    Args:
        model: Torch module mapping a float32 batch to per-class logits
            (classes along dim 1). It is switched to eval mode as a side effect.
        X_val: Array-like feature batch accepted by ``torch.tensor``.
        y_val: Array-like with at least two columns. Column 0 is weighted by
            ``alpha`` and column 1 by ``1 - alpha``; per the original notes
            these are the MRI (ground truth) and triage labels respectively.
        alpha: Weight given to column 0 of ``y_val``.

    Returns:
        1-D NumPy array holding one nonconformity score per sample.
    """
    # Coerce once so nested lists (or CPU tensors) support the column indexing below.
    labels = np.asarray(y_val)

    model.eval()
    with torch.no_grad():
        logits = model(torch.tensor(X_val, dtype=torch.float32))
        stroke_prob = torch.softmax(logits, dim=1).numpy()[:, 1]

    # Nonconformity score: weighted absolute error against both label columns.
    primary_error = np.abs(stroke_prob - labels[:, 0])
    secondary_error = np.abs(stroke_prob - labels[:, 1])
    return alpha * primary_error + (1 - alpha) * secondary_error

def calibrate_conformal_threshold(scores, quantile=0.95):
    """Return the requested quantile of the calibration scores.

    Args:
        scores: Array-like of nonconformity scores.
        quantile: Quantile level passed straight to ``np.quantile``.

    Returns:
        The value of ``scores`` at the given quantile, used as the threshold.
    """
    cutoff = np.quantile(scores, q=quantile)
    return cutoff

def predict_with_conformal(model, X_test, threshold):
    """Run the model on ``X_test`` and build a label set for every sample.

    A sample whose class-1 probability is strictly greater than ``threshold``
    gets the singleton set ``{"Stroke"}``; every other sample gets both
    labels. "Stroke" is therefore a member of every returned set.

    Args:
        model: Torch module producing per-class logits; put into eval mode.
        X_test: Array-like feature batch accepted by ``torch.tensor``.
        threshold: Cut-off compared against the class-1 probability.

    Returns:
        Tuple ``(prediction_sets, predicted_labels, probs)``: a list of sets
        of label names, the argmax class index per sample, and the softmax
        probabilities as a NumPy array.
    """
    model.eval()
    with torch.no_grad():
        batch = torch.tensor(X_test, dtype=torch.float32)
        probs = torch.softmax(model(batch), dim=1).numpy()

    # Build the sets row by row; the comparison stays scalar-vs-scalar.
    prediction_sets = []
    for row in probs:
        if row[1] > threshold:
            prediction_sets.append({"Stroke"})
        else:
            prediction_sets.append({"Stroke", "No-Stroke"})

    hard_labels = np.argmax(probs, axis=1)
    return prediction_sets, hard_labels, probs

# Seed every RNG used below so the placeholder data and the model's initial
# weights are identical on each Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load Data (Placeholder: Replace with actual extracted features)
# Train/val labels have two binary columns (compute_nonconformity reads
# column 0 and column 1 separately); test labels are a single 1-D vector.
X_train, y_train = np.random.rand(100, 4096), np.random.randint(0, 2, (100, 2))  # Early-stage data
X_val, y_val = np.random.rand(50, 4096), np.random.randint(0, 2, (50, 2))  # Middle-stage calibration data
X_test, y_test = np.random.rand(30, 4096), np.random.randint(0, 2, (30,))  # Newest-stage test data

# Train Model
# NOTE(review): StrokeClassifier is not defined or imported in the visible
# cells -- confirm an earlier cell defines it before running on a fresh kernel.
model = StrokeClassifier(input_dim=4096)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train_model(model, X_train, y_train, epochs=10):
    """Train ``model`` with full-batch gradient steps.

    Uses the notebook-level ``optimizer`` and ``criterion`` objects rather
    than taking them as arguments.

    Args:
        model: Torch module producing per-class logits for a float32 batch.
        X_train: Array-like feature batch accepted by ``torch.tensor``.
        y_train: Either a 1-D array-like of class indices, or a 2-D array-like
            whose column 0 is the training target. Column 0 is the label that
            ``compute_nonconformity`` in this file weights as the MRI ground
            truth, so training uses the same column.
        epochs: Number of full-batch optimisation steps.

    Returns:
        None. ``model`` is updated in place and left in train mode.
    """
    labels = np.asarray(y_train)
    if labels.ndim == 2:
        # The two columns are separate binary labels, not a one-hot encoding:
        # argmax over them would yield class 1 only for rows equal to (0, 1).
        labels = labels[:, 0]

    # Convert once; both tensors are identical for every epoch.
    features = torch.tensor(X_train, dtype=torch.float32)
    targets = torch.tensor(labels, dtype=torch.long)

    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(features), targets)
        loss.backward()
        optimizer.step()
# Fit the classifier on the early-stage split.
train_model(model, X_train, y_train)

# Calibration: score the middle-stage split, then derive the threshold from
# those scores using the helper's default quantile.
nonconformity_scores = compute_nonconformity(model, X_val, y_val)
threshold = calibrate_conformal_threshold(nonconformity_scores)

# Inference on the newest-stage split: label sets, hard labels, probabilities.
(
    conformal_predictions,
    predicted_labels,
    predicted_probs,
) = predict_with_conformal(model, X_test, threshold)

# Point-prediction metrics of the hard labels against the test labels.
accuracy = accuracy_score(y_true=y_test, y_pred=predicted_labels)
f1 = f1_score(y_true=y_test, y_pred=predicted_labels)
recall = recall_score(y_true=y_test, y_pred=predicted_labels)
conf_matrix = confusion_matrix(y_true=y_test, y_pred=predicted_labels)

# Conformal Prediction Metrics
def compute_coverage(prediction_sets, y_test):
    """Compute coverage: fraction of samples whose true label is in its prediction set.

    ``predict_with_conformal`` emits sets of class *names* ("Stroke",
    "No-Stroke") while test labels are integer class indices, so a label is
    counted as covered when either the raw label or its class name
    (0 -> "No-Stroke", 1 -> "Stroke") is a member of the set.

    Args:
        prediction_sets: Sequence of sets, one per sample.
        y_test: Sequence of true labels (class indices or class names).

    Returns:
        Covered fraction as a float in [0, 1]; ``0.0`` when ``y_test`` is empty.
    """
    total = len(y_test)
    if total == 0:
        # Nothing to evaluate; avoid dividing by zero.
        return 0.0

    index_to_name = {0: "No-Stroke", 1: "Stroke"}

    covered = 0
    for pred_set, label in zip(prediction_sets, y_test):
        if label in pred_set:
            covered += 1
            continue
        # Fall back to the class name for integer-like labels (incl. NumPy ints).
        try:
            name = index_to_name.get(int(label))
        except (TypeError, ValueError):
            name = None
        if name is not None and name in pred_set:
            covered += 1
    return covered / total

def compute_efficiency(prediction_sets):
    """Compute efficiency as the mean number of labels per prediction set.

    Args:
        prediction_sets: Sequence of sized collections, one per sample.

    Returns:
        Mean set size as computed by ``np.mean``.
    """
    set_sizes = list(map(len, prediction_sets))
    return np.mean(set_sizes)

# Set-level metrics for the conformal predictions on the test split.
coverage = compute_coverage(conformal_predictions, y_test)
efficiency = compute_efficiency(conformal_predictions)

# Output Results
# Preview the prediction sets of the first five test cases (1-based numbering).
preview = conformal_predictions[:5]
for i in range(len(preview)):
    pred_set = preview[i]
    print(f"Test Case {i+1}: Prediction Set = {pred_set}")

# Point-prediction report.
print("\nModel Evaluation:")
for line in (
    f"Accuracy: {accuracy:.4f}",
    f"F1 Score: {f1:.4f}",
    f"Recall: {recall:.4f}",
    "Confusion Matrix:",
):
    print(line)
print(conf_matrix)

# Conformal report.
print("\nConformal Prediction Evaluation:")
for line in (
    f"Coverage: {coverage:.4f}",
    f"Efficiency (Avg. Prediction Set Size): {efficiency:.4f}",
):
    print(line)
