# End-to-End MIMIC SDoH Benchmarking Notebook

This notebook implements a reproducible benchmarking workflow for extracting SDoH/SBDH labels from MIMIC-III discharge summaries using MIMIC-SBDH expert annotations.

## A. Setup

In [None]:
# Install dependencies (Colab only)
import sys
import subprocess
from pathlib import Path

# When the Colab runtime module is loaded, install everything listed in
# requirements.txt with the pip that belongs to the running interpreter.
if 'google.colab' in sys.modules:
    pip_command = [sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt']
    subprocess.check_call(pip_command)

# Standard imports
import json
import logging
import os
import random
import time
import importlib.util

import numpy as np
import pandas as pd
import torch

# Reproducibility: push one shared seed into every RNG this notebook touches.
SEED = 42
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(SEED)
if torch.cuda.is_available():
    # Seed the CUDA generators too when a GPU is present.
    torch.cuda.manual_seed_all(SEED)

# Logging: INFO level, "timestamp | level | message" records.
logging.basicConfig(format='%(asctime)s | %(levelname)s | %(message)s', level=logging.INFO)

# Environment info
for env_name, env_value in (
    ('Python:', sys.version),
    ('Torch:', torch.__version__),
    ('CUDA available:', torch.cuda.is_available()),
):
    print(env_name, env_value)

# Add repo root to sys.path, but only when a `src` directory sits in the
# current working directory.
repo_root = Path.cwd()
src_dir = repo_root / 'src'
if src_dir.exists():
    sys.path.append(str(repo_root))


## B. Configuration

In [None]:
from dataclasses import dataclass
from typing import Optional

@dataclass
class BenchmarkConfig:
    """Notebook-wide settings for the MIMIC SDoH benchmark.

    All values are validated on construction so that a typo in an option or
    an inconsistent split specification fails here rather than deep inside
    data loading or splitting.

    Raises:
        ValueError: if ``task_type`` or ``text_policy`` is not one of the
            documented options, if ``chunksize`` / ``max_length`` is not
            positive, or if the split fractions are not each in (0, 1) and
            summing to 1.
    """

    # Location of the MIMIC-III tables and of the MIMIC-SBDH annotation CSV.
    mimic_root: str = r"C:\Users\Terry Yu\Documents\mimic-iii-clinical-database-1.4"
    sbdh_path: str = r"D:\Social Determinants Research\MIMIC DATASETS\MIMIC-SBDH.csv"
    # Root directory for every artifact the notebook writes.
    output_dir: str = "outputs"
    category_filter: Optional[str] = "Discharge summary"  # set to None for all notes
    task_type: str = "multilabel"  # 'multilabel' or 'binary'
    label_columns: Optional[list[str]] = None  # auto-detect if None
    # Forwarded to the dataset loader's config.
    chunksize: int = 50_000
    # Token length used when tokenizing notes for the transformer models.
    max_length: int = 256
    text_policy: str = "truncate"  # 'truncate' or 'sliding'
    # Subject-level split fractions; together they must cover the whole dataset.
    train_size: float = 0.7
    val_size: float = 0.15
    test_size: float = 0.15

    def __post_init__(self) -> None:
        if self.task_type not in {"multilabel", "binary"}:
            raise ValueError(
                f"task_type must be 'multilabel' or 'binary', got {self.task_type!r}"
            )
        if self.text_policy not in {"truncate", "sliding"}:
            raise ValueError(
                f"text_policy must be 'truncate' or 'sliding', got {self.text_policy!r}"
            )
        if self.chunksize <= 0:
            raise ValueError(f"chunksize must be positive, got {self.chunksize}")
        if self.max_length <= 0:
            raise ValueError(f"max_length must be positive, got {self.max_length}")

        fractions = {
            "train_size": self.train_size,
            "val_size": self.val_size,
            "test_size": self.test_size,
        }
        for field_name, fraction in fractions.items():
            if not 0 < fraction < 1:
                raise ValueError(f"{field_name} must be in (0, 1), got {fraction}")
        total = sum(fractions.values())
        # Tolerance absorbs float rounding such as 0.7 + 0.15 + 0.15.
        if abs(total - 1.0) > 1e-6:
            raise ValueError(
                f"train_size + val_size + test_size must equal 1, got {total}"
            )


config = BenchmarkConfig()
print(config)

# Optional: Google Drive mounting for Colab
# from google.colab import drive
# drive.mount('/content/drive')


## C. Data Loading & Dataset Construction

In [None]:
from pathlib import Path
from src.data_access.real_loader import RealDatasetConfig, load_mimic_sbdh_dataset

# Create the output tree up front so later cells can write into it directly.
output_dir = Path(config.output_dir)
for subdir in ('splits', 'metrics', 'figures', 'cost'):
    (output_dir / subdir).mkdir(parents=True, exist_ok=True)

# Translate the notebook-level settings into the loader's own config object.
loader_options = {
    'mimic_root': Path(config.mimic_root),
    'sbdh_path': Path(config.sbdh_path),
    'category_filter': config.category_filter,
    'chunksize': config.chunksize,
}
real_config = RealDatasetConfig(**loader_options)

# Load data (chunked)
try:
    dataset = load_mimic_sbdh_dataset(real_config)
except Exception as exc:
    # Re-raise with setup guidance. `from exc` keeps the loader's own
    # traceback attached as the cause, so the root failure stays visible.
    raise RuntimeError(
        'Failed to load MIMIC/SBDH data. Ensure paths are correct and labels include SUBJECT_ID + HADM_ID (preferred) '
        'or SUBJECT_ID + NOTE_ID.\n'
        f'Error: {exc}'
    ) from exc

# Explicitly configured label columns win; otherwise every column that is
# neither an identifier nor the note text is treated as a label.
label_columns = config.label_columns or [
    column
    for column in dataset.columns
    if column not in {'subject_id', 'hadm_id', 'note_id', 'text'}
]

print('Loaded rows:', len(dataset))
print('Detected labels:', label_columns)

# Dataset summary
if 'hadm_id' in dataset.columns:
    n_admissions = int(dataset['hadm_id'].nunique())
else:
    # Admission ids are optional in the loaded frame.
    n_admissions = None

summary = {
    'n_notes': int(len(dataset)),
    'n_subjects': int(dataset['subject_id'].nunique()),
    'n_hadm': n_admissions,
    # Column mean of each label, stored as a plain float for JSON.
    'label_prevalence': {name: float(dataset[name].mean()) for name in label_columns},
}

summary_path = output_dir / 'dataset_summary.json'
summary_json = json.dumps(summary, indent=2)
summary_path.write_text(summary_json)
print('Wrote summary to', summary_path)


## D. Task Operationalisation & Splits

In [None]:
from sklearn.model_selection import GroupShuffleSplit
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

# Build train/val/test partitions at the subject level, then map them back to
# note rows, so that all notes of one subject land in the same partition.

# NOTE(review): X, y and subjects are not referenced again in the visible part
# of this notebook; later cells read text/labels from train_df/val_df/test_df.
X = dataset['text'].fillna('')
y = dataset[label_columns].astype(int)
subjects = dataset['subject_id']

# Subject-level labels for leakage-free splits
# One row per subject; each label is the maximum over that subject's notes
# (presumably labels are 0/1, so this reads as "any note positive" — confirm).
# reset_index() gives a fresh 0..n-1 index, which is why the positional indices
# returned by the splitters can be used with .loc on subject_labels below.
subject_labels = dataset.groupby('subject_id')[label_columns].max().reset_index()

if config.task_type == 'multilabel':
    # Stage 1: carve the test subjects out, stratified on the multilabel matrix.
    splitter = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=config.test_size, random_state=SEED)
    subj_train_val_idx, subj_test_idx = next(splitter.split(subject_labels['subject_id'], subject_labels[label_columns]))
    train_val_subjects = subject_labels.loc[subj_train_val_idx, 'subject_id']
    test_subjects = subject_labels.loc[subj_test_idx, 'subject_id']

    # Stage 2: split the remaining pool into train/val. The fraction is
    # expressed relative to the train+val pool, not to the full dataset.
    val_splitter = MultilabelStratifiedShuffleSplit(
        n_splits=1,
        test_size=config.val_size / (config.train_size + config.val_size),
        random_state=SEED,
    )
    subj_train_idx, subj_val_idx = next(
        val_splitter.split(train_val_subjects, subject_labels.loc[subj_train_val_idx, label_columns])
    )
    # Indices from stage 2 are positions within train_val_subjects (which still
    # carries subject_labels' index labels), hence .iloc rather than .loc.
    train_subjects = train_val_subjects.iloc[subj_train_idx]
    val_subjects = train_val_subjects.iloc[subj_val_idx]
else:
    # Non-multilabel path: shuffle split with every subject as its own group.
    # NOTE(review): GroupShuffleSplit does not appear to stratify on the y
    # argument, so this is likely a plain random subject split — confirm.
    gss = GroupShuffleSplit(n_splits=1, test_size=config.test_size, random_state=SEED)
    subj_train_val_idx, subj_test_idx = next(
        gss.split(subject_labels['subject_id'], subject_labels[label_columns].sum(axis=1), groups=subject_labels['subject_id'])
    )
    train_val_subjects = subject_labels.loc[subj_train_val_idx, 'subject_id']
    test_subjects = subject_labels.loc[subj_test_idx, 'subject_id']

    # Same relative val fraction as in the multilabel branch.
    gss_val = GroupShuffleSplit(
        n_splits=1,
        test_size=config.val_size / (config.train_size + config.val_size),
        random_state=SEED,
    )
    subj_train_idx, subj_val_idx = next(
        gss_val.split(train_val_subjects, train_val_subjects, groups=train_val_subjects)
    )
    train_subjects = train_val_subjects.iloc[subj_train_idx]
    val_subjects = train_val_subjects.iloc[subj_val_idx]

# Map subject partitions back to note rows; .copy() detaches them from dataset.
train_df = dataset[dataset['subject_id'].isin(train_subjects)].copy()
val_df = dataset[dataset['subject_id'].isin(val_subjects)].copy()
test_df = dataset[dataset['subject_id'].isin(test_subjects)].copy()

# Persist only subject ids and label columns; the note text is not written.
train_df[['subject_id', *label_columns]].to_csv(output_dir / 'splits' / 'train.csv', index=False)
val_df[['subject_id', *label_columns]].to_csv(output_dir / 'splits' / 'val.csv', index=False)
test_df[['subject_id', *label_columns]].to_csv(output_dir / 'splits' / 'test.csv', index=False)

print('Split sizes:', len(train_df), len(val_df), len(test_df))


## E. Baseline Models (Traditional ML)

In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import f1_score, roc_auc_score
import joblib

# xgboost is optional: the XGBoost baseline is only added when it is installed.
has_xgb = importlib.util.find_spec('xgboost') is not None
if has_xgb:
    import xgboost as xgb

# Word uni/bi-gram TF-IDF features, capped at 40k terms; terms seen in fewer
# than two training notes are dropped.
vectorizer = TfidfVectorizer(
    max_features=40000,
    ngram_range=(1, 2),
    min_df=2,
    lowercase=True,
)

# Missing notes are mapped to empty strings first: the vectorizer rejects NaN
# documents. The vocabulary is fitted on the training split only.
X_train = vectorizer.fit_transform(train_df['text'].fillna(''))
X_val = vectorizer.transform(val_df['text'].fillna(''))
X_test = vectorizer.transform(test_df['text'].fillna(''))

baseline_results = []

# Every stochastic estimator is seeded so reruns of the benchmark reproduce
# the same scores.
models = {
    'LogReg': LogisticRegression(
        max_iter=200,
        class_weight='balanced',
        solver='saga',
        n_jobs=-1,
        random_state=SEED,
    ),
    'RandomForest': RandomForestClassifier(
        n_estimators=200,
        class_weight='balanced',
        n_jobs=-1,
        random_state=SEED,
    ),
}
if has_xgb:
    models['XGBoost'] = xgb.XGBClassifier(
        n_estimators=200,
        learning_rate=0.1,
        max_depth=6,
        subsample=0.8,
        colsample_bytree=0.8,
        eval_metric='logloss',
        tree_method='hist',
        random_state=SEED,
    )


def count_params(clf):
    """Return a rough size measure for a fitted multi-estimator wrapper.

    Walks ``clf.estimators_`` (treated as empty when the attribute is
    missing). An estimator exposing ``coef_`` contributes the element counts
    of ``coef_`` and ``intercept_``; otherwise one exposing
    ``feature_importances_`` contributes that array's element count; any
    other estimator contributes nothing.
    """

    def _estimator_size(est):
        if hasattr(est, 'coef_'):
            return est.coef_.size + est.intercept_.size
        if hasattr(est, 'feature_importances_'):
            return est.feature_importances_.size
        return 0

    fitted = getattr(clf, 'estimators_', [])
    return sum(_estimator_size(est) for est in fitted)

# Serialized baselines go here; created once rather than once per model.
model_dir = output_dir / 'models'
model_dir.mkdir(parents=True, exist_ok=True)

for name, model in models.items():
    # One independent binary classifier per label column.
    clf = OneVsRestClassifier(model)
    start = time.time()
    clf.fit(X_train, train_df[label_columns])
    train_time = time.time() - start

    test_pred = clf.predict(X_test)
    macro_f1 = f1_score(test_df[label_columns], test_pred, average='macro', zero_division=0)
    micro_f1 = f1_score(test_df[label_columns], test_pred, average='micro', zero_division=0)

    # AUROC is scored on the same held-out test split as the F1 values, so the
    # row is internally consistent and comparable with the transformer rows.
    try:
        test_scores = clf.predict_proba(X_test)
        auc = roc_auc_score(test_df[label_columns], test_scores, average='macro')
    except Exception as exc:
        # Best effort: keep the run going, but record why the metric is missing.
        logging.warning('AUROC unavailable for %s: %s', name, exc)
        auc = float('nan')

    # Wall-clock latency for predicting the first 100 test rows.
    start_inf = time.perf_counter()
    _ = clf.predict(X_test[:100])
    inf_latency = time.perf_counter() - start_inf

    baseline_results.append({
        'model': name,
        'macro_f1': macro_f1,
        'micro_f1': micro_f1,
        'auroc_macro': auc,
        'train_time_s': train_time,
        'inference_100_s': inf_latency,
        'param_count': count_params(clf),
    })

    # Store the vectorizer alongside the model so the pair can be reloaded together.
    joblib.dump({'vectorizer': vectorizer, 'model': clf}, model_dir / f'{name}.joblib')

print(pd.DataFrame(baseline_results))


## F. Transformer Models (Clinical NLP)

In [None]:
from datasets import Dataset as HFDataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

# Display name -> Hugging Face checkpoint id for each model to fine-tune.
transformer_models = {
    'BERT': 'bert-base-uncased',
    'BioBERT': 'dmis-lab/biobert-base-cased-v1.2',
    'ClinicalBERT': 'emilyalsentzer/Bio_ClinicalBERT',
}

# Missing notes become empty strings so every element handed to the tokenizer
# (and to the text-based error analysis) is a real str rather than NaN.
train_texts = train_df['text'].fillna('').tolist()
val_texts = val_df['text'].fillna('').tolist()
test_texts = test_df['text'].fillna('').tolist()

# Integer label matrices, one column per entry of label_columns.
train_labels = train_df[label_columns].astype(int).values
val_labels = val_df[label_columns].astype(int).values
test_labels = test_df[label_columns].astype(int).values

# One metrics row per model, and the fitted model/tokenizer kept for later cells.
transformer_results = []
transformer_artifacts = {}

def tokenize_function(tokenizer, texts):
    """Encode ``texts`` with ``tokenizer`` at a fixed length.

    Every sequence is padded to, and truncated at, ``config.max_length``.
    Returns whatever the tokenizer call returns.
    """
    encode_options = {
        'padding': 'max_length',
        'truncation': True,
        'max_length': config.max_length,
    }
    return tokenizer(texts, **encode_options)


def optimize_thresholds(labels, probs):
    """Choose one decision threshold per label column by maximising F1.

    For each column index of ``labels``, nine evenly spaced cut-offs from 0.1
    to 0.9 are tried against the matching column of ``probs``. The cut-off
    with the highest F1 is kept; on ties the lowest cut-off wins.

    Returns a 1-D array with one threshold per label column.
    """
    candidates = np.linspace(0.1, 0.9, 9)

    def _best_cutoff(column):
        scores = [
            f1_score(labels[:, column], (probs[:, column] >= cut).astype(int), zero_division=0)
            for cut in candidates
        ]
        # argmax returns the first maximum, i.e. the lowest cut-off among ties.
        return candidates[int(np.argmax(scores))]

    return np.array([_best_cutoff(column) for column in range(labels.shape[1])])

def compute_metrics(eval_pred):
    """Macro/micro F1 at a fixed 0.5 threshold, for Trainer's per-epoch eval.

    ``eval_pred`` unpacks into ``(logits, labels)``. Logits go through a
    sigmoid and are thresholded at 0.5.
    """
    logits, labels = eval_pred
    probs = torch.sigmoid(torch.tensor(logits)).numpy()
    preds = (probs >= 0.5).astype(int)
    # Labels are stored as floats in the datasets built below; cast back to
    # ints so the metric sees the same indicator type as the predictions.
    labels = np.asarray(labels).astype(int)
    macro_f1 = f1_score(labels, preds, average='macro', zero_division=0)
    micro_f1 = f1_score(labels, preds, average='micro', zero_division=0)
    return {'macro_f1': macro_f1, 'micro_f1': micro_f1}


for name, checkpoint in transformer_models.items():
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    train_enc = tokenize_function(tokenizer, train_texts)
    val_enc = tokenize_function(tokenizer, val_texts)
    test_enc = tokenize_function(tokenizer, test_texts)

    # The multi-label head is trained with a BCE-with-logits loss, which needs
    # float targets; integer labels make the loss computation fail. The integer
    # matrices (train_labels, ...) remain in use for the sklearn metrics below.
    train_ds = HFDataset.from_dict({**train_enc, 'labels': train_labels.astype(np.float32).tolist()})
    val_ds = HFDataset.from_dict({**val_enc, 'labels': val_labels.astype(np.float32).tolist()})
    test_ds = HFDataset.from_dict({**test_enc, 'labels': test_labels.astype(np.float32).tolist()})

    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint,
        num_labels=len(label_columns),
        problem_type='multi_label_classification',
    )

    # NOTE(review): newer transformers releases are believed to rename
    # `evaluation_strategy` to `eval_strategy` — confirm against the installed version.
    args = TrainingArguments(
        output_dir=str(output_dir / 'models_transformers' / name),
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=1,
        evaluation_strategy='epoch',
        save_strategy='epoch',
        load_best_model_at_end=True,
        metric_for_best_model='eval_loss',
        logging_steps=50,
        report_to='none',
        seed=SEED,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_metrics,
    )

    start = time.time()
    trainer.train()
    train_time = time.time() - start

    # Tune per-label thresholds on validation only, then apply them to test.
    val_outputs = trainer.predict(val_ds)
    val_probs = torch.sigmoid(torch.tensor(val_outputs.predictions)).numpy()
    thresholds = optimize_thresholds(val_labels, val_probs)

    test_outputs = trainer.predict(test_ds)
    test_probs = torch.sigmoid(torch.tensor(test_outputs.predictions)).numpy()
    # thresholds has one entry per label and broadcasts across rows.
    preds = (test_probs >= thresholds).astype(int)

    macro_f1 = f1_score(test_labels, preds, average='macro', zero_division=0)
    micro_f1 = f1_score(test_labels, preds, average='micro', zero_division=0)
    try:
        auroc = roc_auc_score(test_labels, test_probs, average='macro')
    except Exception:
        auroc = float('nan')

    # Wall-clock latency for predicting up to the first 100 test rows.
    start_inf = time.perf_counter()
    _ = trainer.predict(test_ds.select(range(min(100, len(test_ds)))))
    inf_latency = time.perf_counter() - start_inf

    param_count = sum(p.numel() for p in model.parameters())

    transformer_results.append({
        'model': name,
        'macro_f1': macro_f1,
        'micro_f1': micro_f1,
        'auroc_macro': auroc,
        'train_time_s': train_time,
        'inference_100_s': inf_latency,
        'param_count': param_count,
    })
    # Kept in memory for the cost-assessment cell.
    transformer_artifacts[name] = {'model': model, 'tokenizer': tokenizer}

print(pd.DataFrame(transformer_results))


## G. Evaluation & Error Analysis

In [None]:
from sklearn.metrics import classification_report, precision_recall_fscore_support, confusion_matrix
import textwrap
import re
import matplotlib.pyplot as plt
import seaborn as sns

# Combined leaderboard: baselines first, then transformers.
metrics_table = pd.DataFrame(baseline_results + transformer_results)
metrics_path = output_dir / 'metrics' / 'metrics_table.csv'
metrics_table.to_csv(metrics_path, index=False)
print('Wrote metrics table to', metrics_path)

# Per-label metrics for last transformer model
# (`preds` still holds the test predictions of the final loop iteration).
precision, recall, f1, _ = precision_recall_fscore_support(test_labels, preds, average=None, zero_division=0)
per_label_df = pd.DataFrame({
    'label': label_columns,
    'precision': precision,
    'recall': recall,
    'f1': f1,
})

plt.figure(figsize=(10, 4))
sns.barplot(data=per_label_df, x='label', y='f1')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
fig_path = output_dir / 'figures' / 'per_label_f1.png'
plt.savefig(fig_path)
plt.close()

# Confusion matrix for most prevalent label
# Prevalence is the positive rate of each label column in the test split; the
# label is chosen by that rate, not by how well the model scored on it.
label_idx = int(np.argmax(test_labels.mean(axis=0)))
prevalent_label = label_columns[label_idx]
# labels=[0, 1] keeps the matrix 2x2 even when only one class shows up.
cm = confusion_matrix(test_labels[:, label_idx], preds[:, label_idx], labels=[0, 1])
plt.figure(figsize=(4, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title(f'Confusion Matrix: {prevalent_label}')
plt.xlabel('Predicted')
plt.ylabel('Actual')
cm_path = output_dir / 'figures' / 'confusion_matrix.png'
plt.tight_layout()
plt.savefig(cm_path)
plt.close()

print(classification_report(test_labels, preds, target_names=label_columns, zero_division=0))

# Error analysis: clipped excerpts only
# Case-insensitive cue lists used to bucket misclassified notes.
NEGATION_PATTERN = re.compile(
    r"\b(no|denies|without|not|negative for)\b",
    flags=re.IGNORECASE,
)
TEMPORAL_PATTERN = re.compile(
    r"\b(history of|previous|prior|formerly|last year)\b",
    flags=re.IGNORECASE,
)
IMPLICIT_PATTERN = re.compile(
    r"\b(low income|food pantry|shelter|unstable housing)\b",
    flags=re.IGNORECASE,
)


def categorize_error(text: str) -> str:
    """Assign ``text`` to the first error bucket whose cue pattern occurs in it.

    Buckets are tried in a fixed priority order: negation, then temporality,
    then implicit. When no cue is found the result is ``'ambiguity'``.
    """
    prioritized_rules = (
        ('negation', NEGATION_PATTERN),
        ('temporality', TEMPORAL_PATTERN),
        ('implicit', IMPLICIT_PATTERN),
    )
    matches = (bucket for bucket, pattern in prioritized_rules if pattern.search(text))
    return next(matches, 'ambiguity')

# Collect at most five test notes whose predicted label vector differs from
# gold, keeping only a whitespace-collapsed excerpt of each note.
errors = []
for gold, pred, text in zip(test_labels, preds, test_texts):
    if np.array_equal(gold, pred):
        continue  # exact match on every label: not an error
    excerpt = textwrap.shorten(text, width=280, placeholder='...')
    errors.append({
        'text_excerpt': excerpt,
        'error_type': categorize_error(text),
        'gold': gold.tolist(),
        'pred': pred.tolist(),
    })
    if len(errors) >= 5:
        break

error_df = pd.DataFrame(errors)
error_df


## H. Computational Cost Assessment

In [None]:
import psutil

# fvcore is optional: FLOP estimates are only attempted when it is installed.
has_fvcore = importlib.util.find_spec('fvcore') is not None
if has_fvcore:
    from fvcore.nn import FlopCountAnalysis

process = psutil.Process(os.getpid())

cost_rows = []
for result in baseline_results + transformer_results:
    model_name = result['model']
    # Baselines are single joblib files; transformers are checkpoint directories.
    model_path = output_dir / 'models' / f'{model_name}.joblib'
    if not model_path.exists():
        model_path = output_dir / 'models_transformers' / model_name

    size_mb = None
    if model_path.exists():
        if model_path.is_file():
            size_mb = model_path.stat().st_size / 1e6
        else:
            # Directory: add up every file beneath it.
            size_mb = sum(p.stat().st_size for p in model_path.rglob('*') if p.is_file()) / 1e6

    flop_estimate = None
    if has_fvcore and model_name in transformer_artifacts:
        model = transformer_artifacts[model_name]['model']
        # Build the dummy batch on the model's own device; after training the
        # model may sit on the GPU, and CPU inputs would make the trace fail.
        device = next(model.parameters()).device
        dummy = torch.randint(0, 100, (1, config.max_length), device=device)
        dummy_mask = torch.ones_like(dummy)
        try:
            # FlopCountAnalysis takes a tensor or a tuple of tensors that is
            # passed positionally. Presumably the model's forward() accepts
            # (input_ids, attention_mask) in that order — confirm per checkpoint.
            flop_estimate = FlopCountAnalysis(model, (dummy, dummy_mask)).total()
        except Exception as exc:
            # Best effort: record why the estimate is missing instead of hiding it.
            logging.warning('FLOP estimate unavailable for %s: %s', model_name, exc)
            flop_estimate = None

    cost_rows.append({
        'model': model_name,
        'train_time_s': result.get('train_time_s'),
        'inference_100_s': result.get('inference_100_s'),
        'param_count': result.get('param_count'),
        # NOTE(review): both memory figures are read here, while the table is
        # built, so they describe the whole process at this moment rather than
        # each model; `rss` is the current resident size, not a recorded peak.
        'peak_rss_mb': process.memory_info().rss / 1e6,
        'gpu_max_allocated_mb': torch.cuda.max_memory_allocated() / 1e6 if torch.cuda.is_available() else None,
        'model_size_mb': size_mb,
        'flops_estimate': flop_estimate,
    })

cost_table = pd.DataFrame(cost_rows)
cost_path = output_dir / 'cost' / 'cost_table.csv'
cost_table.to_csv(cost_path, index=False)
print('Wrote cost table to', cost_path)


## I. Reproducibility & Outputs

In [None]:
# List every artifact written by the earlier cells as an absolute path.
print('Outputs summary:')
splits_dir = output_dir / 'splits'
artifact_paths = (
    summary_path,
    splits_dir / 'train.csv',
    splits_dir / 'val.csv',
    splits_dir / 'test.csv',
    metrics_path,
    cost_path,
    fig_path,
    cm_path,
)
for artifact in artifact_paths:
    print('-', artifact.resolve())
