diff --git a/docs/api/models.rst b/docs/api/models.rst index 7368dec94..5dfaca9ff 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -204,3 +204,4 @@ API Reference models/pyhealth.models.TextEmbedding models/pyhealth.models.BIOT models/pyhealth.models.unified_multimodal_embedding_docs + models/pyhealth.models.califorest diff --git a/docs/api/models/pyhealth.models.califorest.rst b/docs/api/models/pyhealth.models.califorest.rst new file mode 100644 index 000000000..69ee1ff9b --- /dev/null +++ b/docs/api/models/pyhealth.models.califorest.rst @@ -0,0 +1,7 @@ +pyhealth.models.califorest +========================== + +.. automodule:: pyhealth.models.califorest + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/examples/mimic4_califorest.py b/examples/mimic4_califorest.py new file mode 100644 index 000000000..01d26de2d --- /dev/null +++ b/examples/mimic4_califorest.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import os + +import numpy as np +import torch +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import brier_score_loss, roc_auc_score + +from pyhealth.datasets import ( + MIMIC4EHRDataset, + create_sample_dataset, + get_dataloader, +) +from pyhealth.models import CaliForest +from pyhealth.tasks import InHospitalMortalityMIMIC4 + + +# Set your MIMIC-IV dataset path via environment variable before running: +# export MIMIC4_ROOT=/your/path/to/mimiciv/3.1 +ROOT = os.getenv("MIMIC4_ROOT") + + +def evaluate(y_true: np.ndarray, y_prob: np.ndarray) -> dict[str, float]: + """Compute AUROC and Brier score.""" + y_true = np.asarray(y_true).reshape(-1) + y_prob = np.asarray(y_prob).reshape(-1) + return { + "auroc": float(roc_auc_score(y_true, y_prob)), + "brier": float(brier_score_loss(y_true, y_prob)), + } + + +def run_califorest( + X_train: np.ndarray, + y_train: np.ndarray, + X_test: np.ndarray, + y_test: np.ndarray, + calibration: str, +) -> dict[str, float]: + """Train and evaluate CaliForest on tabularized features.""" + train_samples = [] + for i in range(len(X_train)): + train_samples.append( + { + "patient_id": f"train-{i}", + "visit_id": f"train-{i}", + "features": X_train[i].tolist(), + "label": int(y_train[i]), + } + ) + + test_samples = [] + for i in range(len(X_test)): + test_samples.append( + { + "patient_id": f"test-{i}", + "visit_id": f"test-{i}", + "features": X_test[i].tolist(), + "label": int(y_test[i]), + } + ) + + train_dataset = create_sample_dataset( + samples=train_samples, + input_schema={"features": "tensor"}, + output_schema={"label": "binary"}, + dataset_name=f"mimic4_train_tabular_{calibration}", + ) + test_dataset = create_sample_dataset( + samples=test_samples, + input_schema={"features": "tensor"}, + output_schema={"label": "binary"}, + dataset_name=f"mimic4_test_tabular_{calibration}", + ) + + train_loader = get_dataloader( + train_dataset, batch_size=len(train_dataset), shuffle=False + ) + test_loader = get_dataloader( + test_dataset, batch_size=len(test_dataset), shuffle=False + ) + + test_batch = next(iter(test_loader)) + + model = CaliForest( + dataset=train_dataset, + n_estimators=100, + calibration=calibration, + random_state=42, + ) + model.fit(train_loader) + + with torch.no_grad(): + ret = model(**test_batch) + + cali_probs = ret["y_prob"].detach().cpu().numpy().reshape(-1) + return evaluate(y_test, cali_probs) + + +def main(): + if not ROOT: + raise ValueError( + "MIMIC4_ROOT is not set. Example:\n" + "export MIMIC4_ROOT=/your/path/to/mimiciv/3.1" + ) + + print("=" * 80) + print("Loading MIMIC-IV EHR dataset") + print("=" * 80) + + dataset = MIMIC4EHRDataset( + root=ROOT, + tables=["diagnoses_icd", "procedures_icd", "labevents"], + ) + + task = InHospitalMortalityMIMIC4() + sample_dataset = dataset.set_task(task) + + print(f"Total samples: {len(sample_dataset)}") + + subset_size = 2000 + raw_subset_samples = [sample_dataset[i] for i in range(subset_size)] + + clean_subset_samples = [] + for sample in raw_subset_samples: + clean_subset_samples.append( + { + "patient_id": str(sample["patient_id"]), + "visit_id": str(sample["admission_id"]), + "labs": sample["labs"].tolist(), + "mortality": int(sample["mortality"].item()), + } + ) + + subset_dataset = create_sample_dataset( + samples=clean_subset_samples, + input_schema={"labs": "tensor"}, + output_schema={"mortality": "binary"}, + dataset_name="mimic4_mortality_subset", + ) + + loader = get_dataloader(subset_dataset, batch_size=subset_size, shuffle=False) + batch = next(iter(loader)) + + X = batch["labs"].detach().cpu().numpy() + y = batch["mortality"].detach().cpu().numpy().reshape(-1) + + X = X.reshape(X.shape[0], -1) + + print("Flattened feature matrix:", X.shape) + print("Labels:", y.shape) + + split = int(0.8 * len(X)) + X_train, X_test = X[:split], X[split:] + y_train, y_test = y[:split], y[split:] + + print("=" * 80) + print("Baseline Random Forest") + print("=" * 80) + + rf = RandomForestClassifier( + n_estimators=100, + random_state=42, + bootstrap=True, + ) + rf.fit(X_train, y_train) + rf_probs = rf.predict_proba(X_test)[:, 1] + rf_metrics = evaluate(y_test, rf_probs) + print("RF metrics:", rf_metrics) + + print("=" * 80) + print("CaliForest (isotonic calibration)") + print("=" * 80) + isotonic_metrics = run_califorest( + X_train, y_train, X_test, y_test, calibration="isotonic" + ) + print("CaliForest isotonic metrics:", isotonic_metrics) + + print("=" * 80) + print("CaliForest (logistic calibration)") + print("=" * 80) + logistic_metrics = run_califorest( + X_train, y_train, X_test, y_test, calibration="logistic" + ) + print("CaliForest logistic metrics:", logistic_metrics) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5233b1726..59785e7a3 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -44,3 +44,4 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding +from .califorest import CaliForest \ No newline at end of file diff --git a/pyhealth/models/califorest.py b/pyhealth/models/califorest.py new file mode 100644 index 000000000..bdc539eaf --- /dev/null +++ b/pyhealth/models/califorest.py @@ -0,0 +1,233 @@ +""" +Author: Kobe Guo +NetID: kobeg2 + +Paper: CaliForest: Calibrated Random Forests for Healthcare Prediction +Link: https://joyceho.github.io/assets/pdf/paper/park-chil20.pdf + +Description: +Implementation of CaliForest, a calibrated random forest model that applies +post-hoc calibration (isotonic or logistic) to improve probability estimates +for healthcare prediction tasks. +""" + +from __future__ import annotations + +from typing import Dict, List, Optional + +import numpy as np +import torch +import torch.nn as nn +from sklearn.ensemble import RandomForestClassifier +from sklearn.isotonic import IsotonicRegression +from sklearn.linear_model import LogisticRegression + +from pyhealth.datasets import SampleDataset +from pyhealth.models import BaseModel + + +class CaliForest(BaseModel): + """CaliForest model for calibrated probability prediction. + + This model wraps a RandomForestClassifier and applies a post-hoc + calibration step using out-of-bag (OOB) predictions and prediction + variance to improve probability estimates. + + Important: + CaliForest is fit once on the full training set using fit(train_loader). + After fitting, forward() should be used only for inference/evaluation. + This implementation currently supports binary classification only. + + The overall procedure is: + 1. train a random forest classifier, + 2. compute OOB probabilities for each training sample, + 3. estimate prediction uncertainty using variance across tree outputs, + 4. fit a calibration model using uncertainty-weighted samples. + + Args: + dataset: the dataset used to initialize feature and label schemas. + n_estimators: number of trees in the random forest. Default is 100. + max_depth: maximum depth of each tree. Default is None. + calibration: calibration method. Supported values are ``"isotonic"`` + and ``"logistic"``. Default is ``"isotonic"``. + random_state: random seed for reproducibility. Default is 42. + **kwargs: additional compatibility arguments. + + Example: + model = CaliForest(dataset=dataset, n_estimators=10) + model.fit(train_loader) + ret = model(**batch) + print(ret["y_prob"].shape) + """ + + def __init__( + self, + dataset: SampleDataset, + n_estimators: int = 100, + max_depth: Optional[int] = None, + calibration: str = "isotonic", + random_state: int = 42, + **kwargs, + ): + super(CaliForest, self).__init__(dataset) + + assert len(self.label_keys) == 1, "Only one label key is supported" + self.label_key = self.label_keys[0] + + self.n_estimators = n_estimators + self.max_depth = max_depth + self.calibration = calibration + self.random_state = random_state + + if self.calibration not in {"isotonic", "logistic"}: + raise ValueError(f"Unsupported calibration: {self.calibration}") + + self.rf = RandomForestClassifier( + n_estimators=self.n_estimators, + max_depth=self.max_depth, + bootstrap=True, + oob_score=True, + random_state=self.random_state, + ) + + self.calibrator = None + self.is_fitted = False + + + def _build_feature_matrix(self, **kwargs) -> np.ndarray: + """Convert PyHealth batch into NumPy feature matrix.""" + features: List[np.ndarray] = [] + + for key in self.feature_keys: + x = kwargs[key] + + if isinstance(x, torch.Tensor): + arr = x.detach().cpu().numpy() + else: + arr = np.asarray(x) + + if arr.ndim == 1: + arr = arr.reshape(-1, 1) + elif arr.ndim > 2: + arr = arr.reshape(arr.shape[0], -1) + + features.append(arr.astype(np.float32)) + + return np.concatenate(features, axis=1) + + def _build_labels(self, **kwargs) -> np.ndarray: + y = kwargs[self.label_key] + if isinstance(y, torch.Tensor): + y = y.detach().cpu().numpy() + else: + y = np.asarray(y) + return y.reshape(-1) + + def fit(self, train_loader): + """Fit CaliForest on the full training dataloader""" + X_list = [] + y_list = [] + + for batch in train_loader: + X_list.append(self._build_feature_matrix(**batch)) + y_list.append(self._build_labels(**batch)) + + X = np.concatenate(X_list, axis=0) + y = np.concatenate(y_list, axis=0) + + self.fit_model(features=X, labels=y) + return self + + def fit_model(self, **kwargs) -> None: + """Fit RF + calibration model.""" + if "features" in kwargs and "labels" in kwargs: + X = kwargs["features"] + y = kwargs["labels"] + else: + X = self._build_feature_matrix(**kwargs) + y = self._build_labels(**kwargs) + + unique_labels = np.unique(y) + if set(unique_labels.tolist()) != {0, 1}: + raise ValueError( + "CaliForest currently supports binary classification only. " + f"Got labels: {unique_labels.tolist()}" + ) + self.rf.fit(X, y) + + if not hasattr(self.rf, "oob_decision_function_"): + raise RuntimeError("OOB predictions not available.") + + oob_probs = self.rf.oob_decision_function_[:, 1] + + tree_probs = np.stack( + [t.predict_proba(X)[:, 1] for t in self.rf.estimators_], + axis=0, + ) + variances = np.var(tree_probs, axis=0) + + # CaliForest uses inverse tree-level variance so more stable + # predictions have greater influence during calibrator fitting. + weights = 1.0 / (variances + 1e-6) + + if self.calibration == "isotonic": + calibrator = IsotonicRegression(out_of_bounds="clip") + calibrator.fit(oob_probs, y, sample_weight=weights) + self.calibrator = calibrator + else: + calibrator = LogisticRegression() + calibrator.fit( + oob_probs.reshape(-1, 1), + y, + sample_weight=weights, + ) + self.calibrator = calibrator + + self.is_fitted = True + + def predict_proba_numpy(self, **kwargs) -> np.ndarray: + """Predict calibrated probabilities.""" + if not self.is_fitted: + raise RuntimeError("Model must be fitted first.") + + X = self._build_feature_matrix(**kwargs) + rf_probs = self.rf.predict_proba(X)[:, 1] + + if self.calibration == "isotonic": + calibrated = self.calibrator.predict(rf_probs) + else: + calibrated = self.calibrator.predict_proba( + rf_probs.reshape(-1, 1) + )[:, 1] + + return calibrated.reshape(-1, 1) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """PyHealth forward pass.""" + if not self.is_fitted: + raise RuntimeError( + "CaliForest must be fitted before inference. " + "Call model.fit(train_loader) first." + ) + + y_prob_np = self.predict_proba_numpy(**kwargs) + + y_prob = torch.tensor( + y_prob_np, dtype=torch.float32, device=self.device + ) + + eps = 1e-6 + logits = torch.log( + torch.clamp(y_prob, eps, 1 - eps) / + torch.clamp(1 - y_prob, eps, 1 - eps) + ) + + y_true = kwargs[self.label_key].to(self.device) + loss = self.get_loss_function()(logits, y_true) + + return { + "loss": loss, + "y_prob": y_prob, + "y_true": y_true, + "logit": logits, + } \ No newline at end of file diff --git a/tests/core/test_califorest.py b/tests/core/test_califorest.py new file mode 100644 index 000000000..58d89a8be --- /dev/null +++ b/tests/core/test_califorest.py @@ -0,0 +1,150 @@ +import unittest +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import CaliForest + + +class TestCaliForest(unittest.TestCase): + """Test cases for the CaliForest model.""" + + def setUp(self): + """Set up synthetic data, dataset, and model.""" + self.samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "features": [1.0, 2.0, 3.0, 4.0], + "label": 0, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "features": [2.0, 1.5, 0.5, 3.0], + "label": 1, + }, + { + "patient_id": "patient-2", + "visit_id": "visit-2", + "features": [0.5, 0.7, 1.2, 1.8], + "label": 0, + }, + { + "patient_id": "patient-3", + "visit_id": "visit-3", + "features": [3.1, 2.9, 4.0, 1.2], + "label": 1, + }, + ] + + self.input_schema = {"features": "tensor"} + self.output_schema = {"label": "binary"} + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="califorest_test", + ) + + self.model = CaliForest( + dataset=self.dataset, + n_estimators=10, + calibration="isotonic", + random_state=42, + ) + + self.loader = get_dataloader(self.dataset, batch_size=4, shuffle=False) + + def test_model_initialization(self): + """Test that the model initializes correctly.""" + self.assertIsInstance(self.model, CaliForest) + self.assertEqual(self.model.n_estimators, 10) + self.assertEqual(self.model.calibration, "isotonic") + self.assertEqual(self.model.label_key, "label") + self.assertFalse(self.model.is_fitted) + + def test_model_forward(self): + """Test that forward pass works and returns expected keys.""" + batch = next(iter(self.loader)) + self.model.fit(self.loader) + + with torch.no_grad(): + ret = self.model(**batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + self.assertIn("logit", ret) + + self.assertEqual(ret["y_prob"].shape, (4, 1)) + self.assertEqual(ret["y_true"].shape, (4, 1)) + self.assertEqual(ret["logit"].shape, (4, 1)) + self.assertEqual(ret["loss"].dim(), 0) + + def test_probability_range(self): + """Test that predicted probabilities are in [0, 1].""" + batch = next(iter(self.loader)) + self.model.fit(self.loader) + + with torch.no_grad(): + ret = self.model(**batch) + + y_prob = ret["y_prob"] + self.assertTrue(torch.all(y_prob >= 0.0).item()) + self.assertTrue(torch.all(y_prob <= 1.0).item()) + + def test_forward_before_fit_raises(self): + """Test that calling forward before fit raises a clear error.""" + batch = next(iter(self.loader)) + + with self.assertRaises(RuntimeError): + self.model(**batch) + + def test_logistic_calibration(self): + """Test the logistic calibration option.""" + model = CaliForest( + dataset=self.dataset, + n_estimators=10, + calibration="logistic", + random_state=42, + ) + + batch = next(iter(self.loader)) + model.fit(self.loader) + + with torch.no_grad(): + ret = model(**batch) + + self.assertIn("y_prob", ret) + self.assertEqual(ret["y_prob"].shape, (4, 1)) + + def test_isotonic_and_logistic_differ(self): + """Test that isotonic and logistic calibration produce different outputs.""" + iso_model = CaliForest( + dataset=self.dataset, + n_estimators=10, + calibration="isotonic", + random_state=42, + ) + log_model = CaliForest( + dataset=self.dataset, + n_estimators=10, + calibration="logistic", + random_state=42, + ) + + iso_model.fit(self.loader) + log_model.fit(self.loader) + + batch = next(iter(self.loader)) + + with torch.no_grad(): + iso_probs = iso_model(**batch)["y_prob"] + log_probs = log_model(**batch)["y_prob"] + + self.assertFalse(torch.allclose(iso_probs, log_probs)) + + +if __name__ == "__main__": + unittest.main()