# Stroke Prediction — Clean ML Pipeline
This notebook builds a reproducible **scikit-learn** pipeline for a stroke dataset.
It includes preprocessing, class imbalance handling, model comparison, ROC/PR curves, and interpretability.

In [None]:

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn.base import clone
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (roc_auc_score, average_precision_score, roc_curve, precision_recall_curve,
                             confusion_matrix, classification_report)
from sklearn.inspection import permutation_importance

# --- Config ---
# use_demo=True reads sample.csv; use_demo=False reads stroke.csv.
# Both files are looked up inside data_dir (../data relative to the working directory).
use_demo = True
data_dir = Path('../data')
demo_path = data_dir.joinpath('sample.csv')
real_path = data_dir.joinpath('stroke.csv')  # provide your own

# Load
if use_demo:
    path = demo_path
else:
    path = real_path
df = pd.read_csv(path)
df.head()


In [None]:

# Basic EDA-like checks: dimensions, column names, missing values, target balance
print('Shape:', df.shape)
print('Columns:', df.columns.tolist())

# Ten columns with the most missing values, worst first
missing_per_column = df.isna().sum()
print(missing_per_column.sort_values(ascending=False).head(10))

# Relative frequency of each value of the 'stroke' target
class_share = df['stroke'].value_counts(normalize=True)
print(class_share.rename('class balance'))


In [None]:

# Split
# Separate the feature matrix from the 'stroke' target column.
# NOTE(review): every non-target column is used as a feature. If the CSV carries a
# row identifier (e.g. an `id` column), consider dropping it here — confirm against the data.
X = df.drop(columns=['stroke'])
y = df['stroke'].astype(int)

# Route columns by dtype. Numeric and boolean columns go through median imputation
# and scaling; every other column (object, pandas string, category, ...) is treated
# as categorical. Selecting numerics explicitly, instead of "everything that is not
# object", keeps non-object, non-numeric columns out of the median imputer.
num_cols = X.select_dtypes(include=['number', 'bool']).columns.tolist()
cat_cols = [col for col in X.columns if col not in num_cols]

# Numeric branch: fill gaps with the column median, then standardise.
numeric_pre = Pipeline([
    ('impute', SimpleImputer(strategy='median')),
    ('scale', StandardScaler())
])
# Categorical branch: fill gaps with the most frequent value, then one-hot encode.
# handle_unknown='ignore' lets categories unseen during fit pass through transform
# without raising.
categorical_pre = Pipeline([
    ('impute', SimpleImputer(strategy='most_frequent')),
    ('onehot', OneHotEncoder(handle_unknown='ignore'))
])

preprocess = ColumnTransformer([
    ('num', numeric_pre, num_cols),
    ('cat', categorical_pre, cat_cols)
])

# Stratified 75/25 split so both parts keep the same class proportions as y.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25,
                                                    random_state=42, stratify=y)

print('Train size:', X_train.shape, ' Test size:', X_test.shape)


In [None]:

# Models
# Each pipeline gets its own unfitted copy of the preprocessor. Pipeline.fit fits its
# steps in place, so embedding the same ColumnTransformer instance in both pipelines
# would let the second fit overwrite the fitted state the first pipeline relies on.
logreg = Pipeline([('prep', clone(preprocess)),
                   ('clf', LogisticRegression(max_iter=200, class_weight='balanced'))])

rf = Pipeline([('prep', clone(preprocess)),
               ('clf', RandomForestClassifier(n_estimators=400, random_state=42, class_weight='balanced'))])

models = {'LogReg': logreg, 'RandomForest': rf}

# Fit & evaluate
# For every model, store the test-set probabilities, the 0.5-threshold predictions,
# both ranking metrics and the fitted pipeline, keyed by model name.
results = {}
for name, model in models.items():
    model.fit(X_train, y_train)
    proba = model.predict_proba(X_test)[:,1]  # second predict_proba column
    preds = (proba >= 0.5).astype(int)
    auc = roc_auc_score(y_test, proba)
    ap = average_precision_score(y_test, proba)
    results[name] = {'auc': auc, 'ap': ap, 'preds': preds, 'proba': proba, 'model': model}
    print(f"{name}: ROC-AUC={auc:.3f} | PR-AUC={ap:.3f}")


In [None]:

# ROC and PR curves
# One figure per curve type, one line per fitted model, drawn from the stored
# test-set probabilities in `results`.
roc_fig, roc_ax = plt.subplots(figsize=(6, 4))
for model_name, res in results.items():
    false_pos_rate, true_pos_rate, _ = roc_curve(y_test, res['proba'])
    roc_ax.plot(false_pos_rate, true_pos_rate, label=f"{model_name} (AUC={res['auc']:.3f})")
roc_ax.plot([0, 1], [0, 1], '--', lw=1)  # diagonal reference line
roc_ax.set_xlabel('FPR')
roc_ax.set_ylabel('TPR')
roc_ax.set_title('ROC Curves')
roc_ax.legend()
plt.show()

pr_fig, pr_ax = plt.subplots(figsize=(6, 4))
for model_name, res in results.items():
    precision_vals, recall_vals, _ = precision_recall_curve(y_test, res['proba'])
    pr_ax.plot(recall_vals, precision_vals, label=f"{model_name} (AP={res['ap']:.3f})")
pr_ax.set_xlabel('Recall')
pr_ax.set_ylabel('Precision')
pr_ax.set_title('Precision-Recall Curves')
pr_ax.legend()
plt.show()


In [None]:

# Confusion matrix and report for the best model by PR-AUC (class imbalance aware)
# A single max() over the (name, result) pairs yields both the name and its result dict.
best_name, best = max(results.items(), key=lambda item: item[1]['ap'])
print('Best by PR-AUC:', best_name)

best_preds = best['preds']
print(confusion_matrix(y_test, best_preds))
print(classification_report(y_test, best_preds))


In [None]:

# Permutation importance (on the held-out test set) for the best model.
# permutation_importance is model-agnostic, so it runs for whichever pipeline won.
best_model = best['model']

# The whole pipeline is passed in, so the raw input columns of X_test are shuffled
# (before any one-hot encoding happens inside the pipeline). The returned importances
# are therefore aligned one-to-one with X_test.columns, not with the expanded
# post-preprocessing feature names.
feat_names = np.asarray(X_test.columns)

# Score with average precision to match the PR-AUC criterion used to select the best
# model; without `scoring` the estimator's default score (accuracy) would be used.
perm_result = permutation_importance(best_model, X_test, y_test,
                                     scoring='average_precision',
                                     n_repeats=5, random_state=42)
importances = perm_result.importances_mean
order = np.argsort(importances)[::-1][:15]  # indices of the 15 largest mean importances

# Reverse the order for barh so the most important feature ends up at the top.
plt.figure(figsize=(7,5))
plt.barh(range(len(order)), importances[order][::-1])
plt.yticks(range(len(order)), [feat_names[i] for i in order][::-1])
plt.xlabel('Mean decrease in score'); plt.title('Permutation Importance (top 15)'); plt.tight_layout(); plt.show()
