diff --git a/pyhealth/calib/README.md b/pyhealth/calib/README.md index 37e391635..05d1989aa 100644 --- a/pyhealth/calib/README.md +++ b/pyhealth/calib/README.md @@ -14,6 +14,8 @@ Model calibration methods adjust predicted probabilities to better reflect true Temperature scaling (also known as Platt scaling for binary classification) is a simple yet effective calibration method that scales logits by a learned temperature parameter. +**Guarantee**: Empirically reduces Expected Calibration Error (ECE). No formal finite-sample statistical guarantee, but widely effective in practice for improving probability calibration. + **Reference**: - Guo, Chuan, Geoff Pleiss, Yu Sun, and Kilian Q. Weinberger. "On calibration of modern neural networks." ICML 2017. @@ -25,6 +27,8 @@ Temperature scaling (also known as Platt scaling for binary classification) is a Histogram binning is a non-parametric calibration method that bins predictions and adjusts probabilities within each bin. +**Guarantee**: Asymptotically consistent calibration as calibration set size → ∞. Provides better empirical calibration (lower ECE) than uncalibrated models. For top-label calibration, provides distribution-free top-label calibration guarantees. + **References**: - Zadrozny, Bianca, and Charles Elkan. "Learning and making decisions when costs and probabilities are both unknown." In Proceedings of the seventh ACM SIGKDD international conference on Knowledge discovery and data mining, pp. 204-213. 2001. - Gupta, Chirag, and Aaditya Ramdas. "Top-label calibration and multiclass-to-binary reductions." ICLR 2022. @@ -37,6 +41,8 @@ Histogram binning is a non-parametric calibration method that bins predictions a Dirichlet calibration learns a matrix transformation of logits with regularization for improved calibration. +**Guarantee**: More expressive than temperature scaling. Empirically reduces multiclass calibration error (ECE, classwise-ECE) by learning class-specific transformations. Optimizes log-likelihood under Dirichlet prior. + **Reference**: - Kull, Meelis, Miquel Perello Nieto, Markus Kängsepp, Telmo Silva Filho, Hao Song, and Peter Flach. "Beyond temperature scaling: Obtaining well-calibrated multi-class probabilities with dirichlet calibration." NeurIPS 2019. @@ -48,12 +54,34 @@ Dirichlet calibration learns a matrix transformation of logits with regularizati KCal uses kernel density estimation on embeddings for full multiclass calibration. The model must support `embed=True` in forward pass. +**Guarantee**: Leverages learned representations for calibration. Empirically reduces ECE, particularly effective when embedding space captures semantic structure. Provides non-parametric calibration through kernel density estimation in embedding space. + **Reference**: - Lin, Zhen, Shubhendu Trivedi, and Jimeng Sun. "Taking a Step Back with KCal: Multi-Class Kernel-Based Calibration for Deep Neural Networks." ICLR 2023. -## Prediction Set Methods +## Prediction Set Methods (Conformal Prediction) + +Conformal prediction is a framework for constructing prediction sets with formal statistical coverage guarantees. Instead of producing a single prediction, these methods output a set of plausible labels that is guaranteed to contain the true label with a user-specified probability (e.g., 90% or 95%). The key advantage is that these guarantees are **distribution-free** and hold for finite samples without assumptions on the data distribution or model—only requiring that calibration and test data are exchangeable (i.e., drawn from the same distribution). + +For example, with α=0.1 (90% coverage), conformal prediction guarantees that P(Y ∈ C(X)) ≥ 0.9, where C(X) is the prediction set for input X. This is particularly valuable in high-stakes applications like healthcare, where quantifying uncertainty is critical for safe decision-making. + +### BaseConformal (Base Split Conformal Prediction) + +**Modes**: `multiclass` + +**Class**: `pyhealth.calib.predictionset.BaseConformal` + +BaseConformal implements standard split conformal prediction without covariate shift correction. It provides a clean baseline implementation for constructing prediction sets with distribution-free coverage guarantees by calibrating score thresholds on a held-out calibration set. + +**Guarantee**: Distribution-free finite-sample coverage under exchangeability: +- **Marginal**: P(Y ∉ C(X)) ≤ α (with high probability) +- **Class-conditional**: P(Y ∉ C(X) | Y=k) ≤ α_k for each class k -Prediction set methods provide set-valued predictions with statistical coverage guarantees. +No assumptions on the model or data distribution required (only exchangeability). + +**References**: +- Vovk, Vladimir, Alexander Gammerman, and Glenn Shafer. "Algorithmic learning in a random world." Springer, 2005. +- Lei, Jing, et al. "Distribution-free predictive inference for regression." Journal of the American Statistical Association (2018). ### LABEL (Least Ambiguous Set-valued Classifier) @@ -63,6 +91,12 @@ Prediction set methods provide set-valued predictions with statistical coverage LABEL is a conformal prediction method that constructs prediction sets with bounded error levels. Supports both marginal and class-conditional coverage. +**Guarantee**: Distribution-free finite-sample coverage guarantees: +- **Marginal**: P(Y ∉ C(X)) ≤ α +- **Class-conditional**: P(Y ∉ C(X) | Y=k) ≤ α_k for each class k + +Constructs least ambiguous (minimal size) sets subject to coverage constraints. Similar to BaseConformal but optimized for minimal ambiguity. + **Reference**: - Sadinle, Mauricio, Jing Lei, and Larry Wasserman. "Least ambiguous set-valued classifiers with bounded error levels." Journal of the American Statistical Association 114, no. 525 (2019): 223-234. @@ -74,6 +108,12 @@ LABEL is a conformal prediction method that constructs prediction sets with boun SCRIB controls class-specific risk while minimizing prediction set ambiguity through optimized class-specific thresholds. +**Guarantee**: Class-specific risk control with minimal ambiguity: +- **Overall**: P(Y ∉ C(X) | |C(X)|=1) ≤ risk (error rate on singleton predictions) +- **Class-specific**: P(Y ∉ C(X) | Y=k, |C(X)|=1) ≤ risk_k for each class k + +Optimizes class-specific thresholds via coordinate descent to minimize prediction set ambiguity while respecting risk bounds. + **Reference**: - Lin, Zhen, Lucas Glass, M. Brandon Westover, Cao Xiao, and Jimeng Sun. "SCRIB: Set-classifier with Class-specific Risk Bounds for Blackbox Models." AAAI 2022. @@ -85,6 +125,12 @@ SCRIB controls class-specific risk while minimizing prediction set ambiguity thr FavMac constructs prediction sets that maximize value while controlling cost/risk, particularly useful for multilabel classification with asymmetric costs. +**Guarantee**: Conformal cost control with value maximization: +- **Expected cost**: E[Cost(C(X), Y)] ≤ target_cost (in expectation over calibration) +- **Adaptive thresholds**: Dynamically adjusts thresholds online to control false positive rates + +Particularly useful for multilabel tasks with asymmetric costs (e.g., medical diagnosis where false positives/negatives have different costs). + **References**: - Lin, Zhen, Shubhendu Trivedi, Cao Xiao, and Jimeng Sun. "Fast Online Value-Maximizing Prediction Sets with Conformal Cost Control (FavMac)." ICML 2023. - Fisch, Adam, Tal Schuster, Tommi Jaakkola, and Regina Barzilay. "Conformal prediction sets with limited false positives." ICML 2022. @@ -95,10 +141,17 @@ FavMac constructs prediction sets that maximize value while controlling cost/ris **Class**: `pyhealth.calib.predictionset.CovariateLabel` -CovariateLabel extends LABEL to handle covariate shift between calibration and test distributions using likelihood ratio weighting. +CovariateLabel extends LABEL to handle covariate shift between calibration and test distributions using likelihood ratio weighting. The default KDE-based approach follows the CoDrug method, which uses kernel density estimation on embeddings to compute likelihood ratios. Users can also provide custom weights for flexibility. -**Reference**: -- Tibshirani, Ryan J., Rina Foygel Barber, Emmanuel Candes, and Aaditya Ramdas. "Conformal prediction under covariate shift." NeurIPS 2019. +**Guarantee**: Distribution-free coverage under covariate shift: +- **Marginal**: P_test(Y ∉ C(X)) ≤ α on test distribution +- **Class-conditional**: P_test(Y ∉ C(X) | Y=k) ≤ α_k on test distribution + +Uses importance weighting (likelihood ratios w(x) = p_test(x)/p_cal(x)) to correct for distribution shift between calibration and test sets. Valid when weights are well-estimated. Supports KDE-based automatic weighting (CoDrug) or custom user-provided weights. + +**References**: +- Tibshirani, Ryan J., Rina Foygel Barber, Emmanuel Candes, and Aaditya Ramdas. "Conformal prediction under covariate shift." NeurIPS 2019. https://arxiv.org/abs/1904.06019 +- Laghuvarapu, Siddhartha, Zhen Lin, and Jimeng Sun. "Conformal Drug Property Prediction with Density Estimation under Covariate Shift." NeurIPS 2023. https://arxiv.org/abs/2310.12033 ## Usage diff --git a/pyhealth/calib/predictionset/__init__.py b/pyhealth/calib/predictionset/__init__.py index 593ae27ef..44264b80e 100644 --- a/pyhealth/calib/predictionset/__init__.py +++ b/pyhealth/calib/predictionset/__init__.py @@ -1,8 +1,9 @@ """Prediction set construction methods""" +from pyhealth.calib.predictionset.base_conformal import BaseConformal from pyhealth.calib.predictionset.covariate import CovariateLabel from pyhealth.calib.predictionset.favmac import FavMac from pyhealth.calib.predictionset.label import LABEL from pyhealth.calib.predictionset.scrib import SCRIB -__all__ = ["LABEL", "SCRIB", "FavMac", "CovariateLabel"] +__all__ = ["BaseConformal", "LABEL", "SCRIB", "FavMac", "CovariateLabel"] diff --git a/pyhealth/calib/predictionset/base_conformal/__init__.py b/pyhealth/calib/predictionset/base_conformal/__init__.py new file mode 100644 index 000000000..b1ceefdee --- /dev/null +++ b/pyhealth/calib/predictionset/base_conformal/__init__.py @@ -0,0 +1,289 @@ +""" +Base Conformal Prediction (Split Conformal) + +Standard split conformal prediction for multiclass classification without +covariate shift correction. + +This method constructs prediction sets with coverage guarantees by calibrating +score thresholds on a held-out calibration set. + +Paper: + Vovk, Vladimir, Alexander Gammerman, and Glenn Shafer. + "Algorithmic learning in a random world." Springer, 2005. + + Papadopoulos, Harris, Kostas Proedrou, Volodya Vovk, and Alex Gammerman. + "Inductive confidence machines for regression." ECML 2002. +""" + +from typing import Dict, Union + +import numpy as np +import torch +from torch.utils.data import IterableDataset + +from pyhealth.calib.base_classes import SetPredictor +from pyhealth.calib.utils import prepare_numpy_dataset +from pyhealth.models import BaseModel + +__all__ = ["BaseConformal"] + + +def _query_quantile(scores: np.ndarray, alpha: float) -> float: + """Compute the alpha-quantile of scores for conformal prediction. + + Args: + scores: Array of conformity scores + alpha: Quantile level (between 0 and 1), typically the miscoverage rate + + Returns: + The alpha-quantile of scores + """ + scores = np.sort(scores) + N = len(scores) + # Use ceiling to get conservative coverage + loc = int(np.ceil(alpha * (N + 1))) - 1 + return -np.inf if loc == -1 else scores[loc] + + +class BaseConformal(SetPredictor): + """Base Conformal Prediction for multiclass classification. + + This implements standard split conformal prediction, which constructs + prediction sets with distribution-free coverage guarantees. The method + calibrates thresholds on a calibration set and uses them to construct + prediction sets on test data. + + The method guarantees that: + - For marginal coverage (alpha is float): P(Y not in C(X)) <= alpha + - For class-conditional coverage (alpha is array): P(Y not in C(X) | Y=k) <= alpha[k] + + where C(X) denotes the prediction set for input X. + + Papers: + Vovk, Vladimir, Alexander Gammerman, and Glenn Shafer. + "Algorithmic learning in a random world." Springer, 2005. + + Lei, Jing, Max G'Sell, Alessandro Rinaldo, Ryan J. Tibshirani, + and Larry Wasserman. "Distribution-free predictive inference for + regression." Journal of the American Statistical Association (2018). + + Args: + model: A trained base model that outputs predicted probabilities + alpha: Target miscoverage rate(s). Can be: + - float: marginal coverage P(Y not in C(X)) <= alpha + - array: class-conditional P(Y not in C(X) | Y=k) <= alpha[k] + score_type: Type of conformity score to use. Options: + - "aps": Adaptive Prediction Sets (default, uses probability scores) + - "threshold": Simple threshold on probabilities + debug: Whether to use debug mode (processes fewer samples) + + Examples: + >>> from pyhealth.datasets import ISRUCDataset, split_by_patient + >>> from pyhealth.datasets import get_dataloader + >>> from pyhealth.models import SparcNet + >>> from pyhealth.tasks import sleep_staging_isruc_fn + >>> from pyhealth.calib.predictionset.base_conformal import BaseConformal + >>> + >>> # Prepare data + >>> sleep_ds = ISRUCDataset("/data/ISRUC-I").set_task( + ... sleep_staging_isruc_fn) + >>> train_data, val_data, test_data = split_by_patient( + ... sleep_ds, [0.6, 0.2, 0.2]) + >>> + >>> # Train model + >>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"], + ... label_key="label", mode="multiclass") + >>> # ... training code ... + >>> + >>> # Create conformal predictor with marginal coverage + >>> conformal_model = BaseConformal(model, alpha=0.1) + >>> conformal_model.calibrate(cal_dataset=val_data) + >>> + >>> # Evaluate + >>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False) + >>> from pyhealth.trainer import Trainer, get_metrics_fn + >>> y_true_all, y_prob_all, _, extra_output = Trainer( + ... model=conformal_model).inference(test_dl, + ... additional_outputs=['y_predset']) + >>> print(get_metrics_fn(conformal_model.mode)( + ... y_true_all, y_prob_all, + ... metrics=['accuracy', 'miscoverage_ps', 'avg_set_size'], + ... y_predset=extra_output['y_predset'])) + {'accuracy': 0.71, 'miscoverage_ps': 0.095, 'avg_set_size': 1.8} + >>> + >>> # Class-conditional coverage + >>> conformal_model_cc = BaseConformal( + ... model, alpha=[0.1, 0.15, 0.1, 0.1, 0.1]) + >>> conformal_model_cc.calibrate(cal_dataset=val_data) + """ + + def __init__( + self, + model: BaseModel, + alpha: Union[float, np.ndarray], + score_type: str = "aps", + debug: bool = False, + **kwargs, + ) -> None: + super().__init__(model, **kwargs) + + if model.mode != "multiclass": + raise NotImplementedError( + "BaseConformal only supports multiclass classification" + ) + + self.mode = self.model.mode + + # Freeze model parameters + for param in model.parameters(): + param.requires_grad = False + self.model.eval() + + self.device = model.device + self.debug = debug + self.score_type = score_type + + # Store alpha + if not isinstance(alpha, float): + alpha = np.asarray(alpha) + self.alpha = alpha + + # Will be set during calibration + self.t = None + + def _compute_conformity_scores( + self, y_prob: np.ndarray, y_true: np.ndarray + ) -> np.ndarray: + """Compute conformity scores from predictions and true labels. + + Args: + y_prob: Predicted probabilities of shape (N, K) + y_true: True class labels of shape (N,) + + Returns: + Conformity scores of shape (N,) + """ + N = len(y_true) + if self.score_type == "aps" or self.score_type == "threshold": + # Use probability of true class as conformity score + # Higher score = more conforming (better prediction) + scores = y_prob[np.arange(N), y_true] + else: + raise ValueError(f"Unknown score_type: {self.score_type}") + + return scores + + def calibrate(self, cal_dataset: IterableDataset): + """Calibrate the thresholds for prediction set construction. + + Args: + cal_dataset: Calibration set (held-out validation data) + """ + # Get predictions and true labels + cal_dataset_dict = prepare_numpy_dataset( + self.model, + cal_dataset, + ["y_prob", "y_true"], + debug=self.debug, + ) + + y_prob = cal_dataset_dict["y_prob"] + y_true = cal_dataset_dict["y_true"] + N, K = y_prob.shape + + # Compute conformity scores + conformity_scores = self._compute_conformity_scores(y_prob, y_true) + + # Compute quantile thresholds + if isinstance(self.alpha, float): + # Marginal coverage: single threshold + t = _query_quantile(conformity_scores, self.alpha) + else: + # Class-conditional coverage: one threshold per class + if len(self.alpha) != K: + raise ValueError( + f"alpha must have length {K} for class-conditional " + f"coverage, got {len(self.alpha)}" + ) + t = [] + for k in range(K): + mask = y_true == k + if np.sum(mask) > 0: + class_scores = conformity_scores[mask] + t_k = _query_quantile(class_scores, self.alpha[k]) + else: + # If no calibration examples, use -inf (include all) + print( + f"Warning: No calibration examples for class {k}, " + "using -inf threshold" + ) + t_k = -np.inf + t.append(t_k) + + self.t = torch.tensor(t, device=self.device) + + if self.debug: + print(f"Calibrated thresholds: {self.t}") + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward propagation with prediction set construction. + + Returns: + Dictionary with all results from base model, plus: + - y_predset: Boolean tensor indicating which classes + are in the prediction set + """ + if self.t is None: + raise RuntimeError( + "Model must be calibrated before inference. " + "Call calibrate() first." + ) + + pred = self.model(**kwargs) + + # Construct prediction set by thresholding probabilities + # Include classes with probability >= threshold + pred["y_predset"] = pred["y_prob"] >= self.t + + return pred + + +if __name__ == "__main__": + # Example usage + from pyhealth.datasets import ISRUCDataset, split_by_patient, get_dataloader + from pyhealth.models import SparcNet + from pyhealth.tasks import sleep_staging_isruc_fn + + sleep_ds = ISRUCDataset("/srv/local/data/trash", dev=True).set_task( + sleep_staging_isruc_fn + ) + train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2]) + + model = SparcNet( + dataset=sleep_ds, + feature_keys=["signal"], + label_key="label", + mode="multiclass", + ) + + # Marginal coverage + conformal_model = BaseConformal(model, alpha=0.1) + conformal_model.calibrate(cal_dataset=val_data) + + # Evaluate + from pyhealth.trainer import Trainer, get_metrics_fn + + test_dl = get_dataloader(test_data, batch_size=32, shuffle=False) + y_true_all, y_prob_all, _, extra_output = Trainer( + model=conformal_model + ).inference(test_dl, additional_outputs=["y_predset"]) + + print( + get_metrics_fn(conformal_model.mode)( + y_true_all, + y_prob_all, + metrics=["accuracy", "miscoverage_ps"], + y_predset=extra_output["y_predset"], + ) + ) + diff --git a/pyhealth/calib/predictionset/covariate/covariate_label.py b/pyhealth/calib/predictionset/covariate/covariate_label.py index 602944b15..f90430be8 100644 --- a/pyhealth/calib/predictionset/covariate/covariate_label.py +++ b/pyhealth/calib/predictionset/covariate/covariate_label.py @@ -2,12 +2,24 @@ Covariate Shift Adaptive Conformal Prediction. This module implements conformal prediction with covariate shift correction -using likelihood ratio weighting based on density estimation. +using likelihood ratio weighting. The implementation supports both: +1. KDE-based density estimation for automatic weight computation +2. User-provided custom weights for flexibility -Paper: +The KDE-based correction approach is based on the CoDrug method, which uses +energy-based models and kernel density estimation to assess molecular densities +and construct weighted conformal prediction sets. + +Papers: Tibshirani, Ryan J., Rina Foygel Barber, Emmanuel Candes, and Aaditya Ramdas. "Conformal prediction under covariate shift." Advances in neural information processing systems 32 (2019). + https://arxiv.org/abs/1904.06019 + + Laghuvarapu, Siddhartha, Zhen Lin, and Jimeng Sun. + "Conformal Drug Property Prediction with Density Estimation under + Covariate Shift." NeurIPS 2023. + https://arxiv.org/abs/2310.12033 """ from typing import Callable, Dict, Optional, Union @@ -33,9 +45,19 @@ def fit_kde( ) -> tuple[Callable, Callable]: """Fit KDEs on calibration and test embeddings using PyHealth's KDE. + This implements the KDE-based density estimation approach from the CoDrug + paper (Laghuvarapu et al., NeurIPS 2023) for computing likelihood ratios + under covariate shift. The method uses kernel density estimation on both + calibration and test embeddings to estimate p_test(x) / p_cal(x). + This uses the PyHealth torch-based RBF kernel density estimator which is more efficient than sklearn for GPU computation. + Reference: + Laghuvarapu, S., Lin, Z., & Sun, J. (2023). Conformal Drug Property + Prediction with Density Estimation under Covariate Shift. NeurIPS 2023. + https://arxiv.org/abs/2310.12033 + Args: cal_embeddings: Calibration embeddings as numpy array of shape (n_cal_samples, embedding_dim) @@ -187,10 +209,22 @@ class CovariateLabel(SetPredictor): reweighting calibration examples according to the likelihood ratio between test and calibration densities. - Paper: + The default KDE-based approach follows the CoDrug method (Laghuvarapu et al., + NeurIPS 2023), which uses kernel density estimation on embeddings to compute + likelihood ratios. Alternatively, users can provide custom weights directly + for more flexibility (e.g., from importance sampling, propensity scores, or + domain-specific methods). + + Papers: Tibshirani, Ryan J., Rina Foygel Barber, Emmanuel Candes, and Aaditya Ramdas. "Conformal prediction under covariate shift." Advances in neural information processing systems 32 (2019). + https://arxiv.org/abs/1904.06019 + + Laghuvarapu, Siddhartha, Zhen Lin, and Jimeng Sun. + "Conformal Drug Property Prediction with Density Estimation under + Covariate Shift." NeurIPS 2023. + https://arxiv.org/abs/2310.12033 Args: model: A trained base model @@ -200,13 +234,17 @@ class CovariateLabel(SetPredictor): kde_test: Optional density estimator fitted on test distribution. Should be a callable that takes embeddings (numpy array) and returns density estimates. Can be obtained via fit_kde(). + Used for KDE-based likelihood ratio weighting (CoDrug approach). kde_cal: Optional density estimator fitted on calibration distribution. Should be a callable that takes embeddings (numpy array) and returns density estimates. + Used for KDE-based likelihood ratio weighting (CoDrug approach). debug: Whether to use debug mode (processes fewer samples for faster iteration) Examples: + **Example 1: KDE-based approach (CoDrug method)** + >>> from pyhealth.datasets import ISRUCDataset >>> from pyhealth.datasets import split_by_patient, get_dataloader >>> from pyhealth.models import SparcNet @@ -240,12 +278,8 @@ class CovariateLabel(SetPredictor): >>> cal_embs = extract_embeddings(model, val_data) >>> test_embs = extract_embeddings(model, test_data) >>> - >>> # Fit KDEs - >>> kde_cal, kde_test = fit_kde(cal_embs, test_embs) - >>> - >>> # Create covariate-adaptive set predictor - >>> cal_model = CovariateLabel(model, alpha=0.1, - ... kde_test=kde_test, kde_cal=kde_cal) + >>> # KDE-based approach: automatically compute weights + >>> cal_model = CovariateLabel(model, alpha=0.1) >>> cal_model.calibrate(cal_dataset=val_data, ... cal_embeddings=cal_embs, test_embeddings=test_embs) >>> @@ -259,6 +293,19 @@ class CovariateLabel(SetPredictor): >>> print(get_metrics_fn(cal_model.mode)( ... y_true_all, y_prob_all, metrics=['accuracy', 'miscoverage_ps'], ... y_predset=extra_output['y_predset'])) + + **Example 2: Custom weights approach** + + >>> # If you have your own covariate shift correction method + >>> # (e.g., importance sampling, propensity scores, etc.) + >>> def compute_custom_weights(cal_data, test_data): + ... # Your custom weight computation + ... # Should return weights proportional to p_test(x) / p_cal(x) + ... return weights # shape: (n_cal,) + >>> + >>> custom_weights = compute_custom_weights(val_data, test_data) + >>> cal_model = CovariateLabel(model, alpha=0.1) + >>> cal_model.calibrate(cal_dataset=val_data, cal_weights=custom_weights) """ def __init__( @@ -309,42 +356,53 @@ def calibrate( cal_dataset: IterableDataset, cal_embeddings: Optional[np.ndarray] = None, test_embeddings: Optional[np.ndarray] = None, + cal_weights: Optional[np.ndarray] = None, ): """Calibrate the thresholds with covariate shift correction. + This method supports three approaches for handling covariate shift: + + 1. **KDE-based (CoDrug approach)**: Provide cal_embeddings and + test_embeddings (and optionally kde_test/kde_cal). The method will + use kernel density estimation to compute likelihood ratios. + + 2. **Custom weights**: Directly provide cal_weights computed from your + own covariate shift correction method (e.g., importance sampling, + propensity scores, discriminator-based methods, etc.). + + 3. **Pre-fitted KDEs**: Provide kde_test and kde_cal during initialization + along with cal_embeddings here. + Args: cal_dataset: Calibration set cal_embeddings: Optional pre-computed calibration embeddings of shape (n_cal, embedding_dim). If provided along with test_embeddings and KDEs are not set, will be used to - compute likelihood ratios. + compute likelihood ratios via KDE (CoDrug approach). test_embeddings: Optional pre-computed test embeddings of shape (n_test, embedding_dim). Used with cal_embeddings - for likelihood ratio computation. + for KDE-based likelihood ratio computation. + cal_weights: Optional custom weights for calibration samples + of shape (n_cal,). If provided, these weights will be used + directly instead of computing likelihood ratios via KDE. + Weights should represent importance weights or likelihood ratios + p_test(x) / p_cal(x). These will be normalized internally. Note: - You must either: - 1. Provide kde_test and kde_cal during initialization, OR - 2. Provide cal_embeddings and test_embeddings here - - If you provide embeddings, likelihood ratios will be computed - by evaluating the KDEs on the calibration embeddings only. + You must provide ONE of: + 1. cal_weights (custom weights), OR + 2. kde_test and kde_cal during initialization, OR + 3. cal_embeddings and test_embeddings here + + Examples: + >>> # Approach 1: KDE-based (CoDrug) + >>> model.calibrate(cal_dataset, cal_embeddings, test_embeddings) + >>> + >>> # Approach 2: Custom weights (e.g., from importance sampling) + >>> custom_weights = compute_importance_weights(cal_data, test_data) + >>> model.calibrate(cal_dataset, cal_weights=custom_weights) """ - # Check if we have KDEs - if self.kde_test is None or self.kde_cal is None: - if cal_embeddings is None or test_embeddings is None: - raise ValueError( - "Must provide either:\n" - " 1. kde_test and kde_cal during __init__, OR\n" - " 2. cal_embeddings and test_embeddings during " - "calibrate()" - ) - - # Fit KDEs if embeddings provided - print("Fitting KDEs on provided embeddings...") - self.kde_cal, self.kde_test = fit_kde(cal_embeddings, test_embeddings) - - # Get predictions and true labels + # Get predictions and true labels first cal_dataset_dict = prepare_numpy_dataset( self.model, cal_dataset, @@ -356,20 +414,49 @@ def calibrate( y_true = cal_dataset_dict["y_true"] N, K = y_prob.shape - # Use provided embeddings or extract from calibration data - if cal_embeddings is not None: - X = cal_embeddings + # Determine weights: either custom or KDE-based + if cal_weights is not None: + # Use custom weights provided by user + if len(cal_weights) != N: + raise ValueError( + f"cal_weights must have length {N} (size of calibration set), " + f"got {len(cal_weights)}" + ) + likelihood_ratios = np.asarray(cal_weights, dtype=np.float64) + print("Using custom calibration weights") else: - # KDEs should already be provided in this case - # We just need to get the embeddings for likelihood ratio - # This assumes the model outputs embeddings - raise NotImplementedError( - "Automatic embedding extraction not yet supported. " - "Please provide cal_embeddings and test_embeddings." - ) + # Use KDE-based approach (CoDrug method) + # Check if we have KDEs + if self.kde_test is None or self.kde_cal is None: + if cal_embeddings is None or test_embeddings is None: + raise ValueError( + "Must provide ONE of:\n" + " 1. cal_weights (custom weights), OR\n" + " 2. kde_test and kde_cal during __init__, OR\n" + " 3. cal_embeddings and test_embeddings during calibrate()" + ) - # Compute likelihood ratios for covariate shift correction - likelihood_ratios = _compute_likelihood_ratio(self.kde_test, self.kde_cal, X) + # Fit KDEs if embeddings provided + print("Fitting KDEs on provided embeddings (CoDrug approach)...") + self.kde_cal, self.kde_test = fit_kde(cal_embeddings, test_embeddings) + + # Use provided embeddings or extract from calibration data + if cal_embeddings is not None: + X = cal_embeddings + else: + # KDEs should already be provided in this case + # We just need to get the embeddings for likelihood ratio + # This assumes the model outputs embeddings + raise NotImplementedError( + "Automatic embedding extraction not yet supported. " + "Please provide cal_embeddings and test_embeddings." + ) + + # Compute likelihood ratios using KDE + print("Computing likelihood ratios via KDE...") + likelihood_ratios = _compute_likelihood_ratio( + self.kde_test, self.kde_cal, X + ) # Normalize weights weights = likelihood_ratios / np.sum(likelihood_ratios) @@ -419,11 +506,18 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if __name__ == "__main__": - # Example usage (requires actual density estimators) + """ + Demonstration of three approaches for covariate shift correction: + 1. Embeddings approach: Automatic KDE computation (CoDrug method) + 2. Pre-fitted KDEs approach: User provides KDE estimators + 3. Custom weights approach: User provides custom importance weights + """ from pyhealth.datasets import ISRUCDataset, split_by_patient, get_dataloader from pyhealth.models import SparcNet from pyhealth.tasks import sleep_staging_isruc_fn + from pyhealth.trainer import Trainer, get_metrics_fn + # Setup data and model sleep_ds = ISRUCDataset("/srv/local/data/trash", dev=True).set_task( sleep_staging_isruc_fn ) @@ -432,27 +526,144 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: model = SparcNet( dataset=sleep_ds, feature_keys=["signal"], label_key="label", mode="multiclass" ) + # ... Train the model here ... + + # Helper function to extract embeddings (mock implementation) + def extract_embeddings(model, dataset): + """Extract embeddings from model for a dataset.""" + # In practice, you would do: + # loader = get_dataloader(dataset, batch_size=32, shuffle=False) + # all_embs = [] + # for batch in loader: + # batch['embed'] = True + # output = model(**batch) + # all_embs.append(output['embed'].cpu().numpy()) + # return np.concatenate(all_embs, axis=0) + + # For demo, return random embeddings + n_samples = len(dataset) + embedding_dim = 64 + return np.random.randn(n_samples, embedding_dim) + + print("=" * 80) + print("APPROACH 1: Embeddings (Automatic KDE - CoDrug Method)") + print("=" * 80) + print("This approach automatically computes KDEs from embeddings.") + print("Best for: When you have model embeddings and want automatic density estimation.\n") + + # Extract embeddings from calibration and test sets + cal_embeddings = extract_embeddings(model, val_data) + test_embeddings = extract_embeddings(model, test_data) + + # Create model and calibrate with embeddings + cal_model_1 = CovariateLabel(model, alpha=0.1) + cal_model_1.calibrate( + cal_dataset=val_data, + cal_embeddings=cal_embeddings, + test_embeddings=test_embeddings + ) - # Note: In practice, you would fit proper KDE estimators here - # For demonstration, using dummy estimators - def dummy_kde(data): - return np.ones(len(data)) + # Evaluate + test_dl = get_dataloader(test_data, batch_size=32, shuffle=False) + y_true, y_prob, _, extra = Trainer(model=cal_model_1).inference( + test_dl, additional_outputs=["y_predset"] + ) + metrics_1 = get_metrics_fn(cal_model_1.mode)( + y_true, y_prob, + metrics=["accuracy", "miscoverage_ps"], + y_predset=extra["y_predset"] + ) + print(f"Results: {metrics_1}\n") + + print("=" * 80) + print("APPROACH 2: Pre-fitted KDEs") + print("=" * 80) + print("This approach uses pre-computed KDE estimators.") + print("Best for: When you want control over KDE parameters or reuse KDEs.\n") + + # Fit KDEs separately with custom parameters + kde_cal, kde_test = fit_kde( + cal_embeddings, + test_embeddings, + bandwidth=0.5, # Custom bandwidth + kernel="rbf" + ) - cal_model = CovariateLabel(model, alpha=0.1, kde_test=dummy_kde, kde_cal=dummy_kde) - cal_model.calibrate(cal_dataset=val_data) + # Create model with pre-fitted KDEs + cal_model_2 = CovariateLabel( + model, + alpha=0.1, + kde_test=kde_test, + kde_cal=kde_cal + ) + cal_model_2.calibrate( + cal_dataset=val_data, + cal_embeddings=cal_embeddings # Still need embeddings for likelihood ratio computation + ) # Evaluate - from pyhealth.trainer import Trainer, get_metrics_fn + y_true, y_prob, _, extra = Trainer(model=cal_model_2).inference( + test_dl, additional_outputs=["y_predset"] + ) + metrics_2 = get_metrics_fn(cal_model_2.mode)( + y_true, y_prob, + metrics=["accuracy", "miscoverage_ps"], + y_predset=extra["y_predset"] + ) + print(f"Results: {metrics_2}\n") + + print("=" * 80) + print("APPROACH 3: Custom Weights") + print("=" * 80) + print("This approach uses user-provided importance weights.") + print("Best for: Alternative covariate shift methods (importance sampling,") + print(" propensity scores, discriminator-based, domain-specific).\n") + + # Compute custom weights using your own method + # Examples of custom weight computation: + + # Option A: Uniform weights (no covariate shift correction) + custom_weights = np.ones(len(val_data)) + + # Option B: Importance sampling weights (mock example) + # In practice, you might use: + # - Discriminator-based methods + # - Propensity score matching + # - Domain adaptation techniques + # - Energy-based models + # custom_weights = compute_importance_weights(val_data, test_data) + + # Option C: Exponential weights based on distance (mock example) + # distances = compute_distribution_distances(val_data, test_data) + # custom_weights = np.exp(-distances) + + print(f"Using custom weights (shape: {custom_weights.shape})") + + # Create model and calibrate with custom weights + cal_model_3 = CovariateLabel(model, alpha=0.1) + cal_model_3.calibrate( + cal_dataset=val_data, + cal_weights=custom_weights # Provide weights directly + ) - test_dl = get_dataloader(test_data, batch_size=32, shuffle=False) - y_true_all, y_prob_all, _, extra_output = Trainer(model=cal_model).inference( + # Evaluate + y_true, y_prob, _, extra = Trainer(model=cal_model_3).inference( test_dl, additional_outputs=["y_predset"] ) - print( - get_metrics_fn(cal_model.mode)( - y_true_all, - y_prob_all, - metrics=["accuracy", "miscoverage_ps"], - y_predset=extra_output["y_predset"], - ) + metrics_3 = get_metrics_fn(cal_model_3.mode)( + y_true, y_prob, + metrics=["accuracy", "miscoverage_ps"], + y_predset=extra["y_predset"] ) + print(f"Results: {metrics_3}\n") + + print("=" * 80) + print("SUMMARY") + print("=" * 80) + print("Approach 1 (Embeddings): ", metrics_1) + print("Approach 2 (Pre-fitted KDEs):", metrics_2) + print("Approach 3 (Custom Weights): ", metrics_3) + print("\nAll three approaches are valid and can be chosen based on your needs!") + print("- Use Approach 1 for simplicity with embeddings (CoDrug method)") + print("- Use Approach 2 for fine-grained control over KDE parameters") + print("- Use Approach 3 for alternative covariate shift correction methods")