Skip to content

Commit

Permalink
feat: local c2st metric (#1109)
Browse files Browse the repository at this point in the history
* lc2st class - first imp

* null hypothesis

* start of notebook

* LC2ST class and notebook version 2

* notebook with graphical diagnostics on GaussianMixture

* missing text in notebook and final fixes

* move gaussian_mixture model from utils to sbi.simulators

* ruff fix

* ruff fix

* typing fix and small doc changes

* fix bug in return `statistic_data`, no prepare_for_sbi in notebook

* bug fix in args of `statistics_data`

* fixes suggested by reviewer @agramfort, doc and typing

* changes simulator gaussian_mixture

* variable name changes pep8 and sbi convention

* more explicit method names and custom clf_kwargs

* remove pandas dependency

* tutorial results description and other fixes for PR

* clarifications lc2st-nf

* ruff check fix

* 10 --> 100 test runs

* negatif --> negativ

* rebase changes + pytest fix

* ensembling

* ruff fix

* add reference for pp-plot

Co-authored-by: Peter Steinbach <p.steinbach@hzdr.de>

* tutorial changes and ruff

* change the default n_ensemble back to 1, explain in tutorial and lc2st doc

* change the default n_ensemble back to 1, explain in tutorial and lc2st doc

* ensembling, clf-choice in tutorial, lc2st-nf description in doc

* pyright fix

---------

Co-authored-by: Peter Steinbach <p.steinbach@hzdr.de>
Co-authored-by: Jan <janfb@users.noreply.github.com>
  • Loading branch information
3 people committed May 17, 2024
1 parent fe55b1c commit 3c1e725
Show file tree
Hide file tree
Showing 9 changed files with 2,426 additions and 17 deletions.
3 changes: 3 additions & 0 deletions sbi/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
conditional_marginal_plot,
conditional_pairplot,
marginal_plot,
marginal_plot_with_probs_intensity,
pairplot,
pp_plot,
pp_plot_lc2st,
sbc_rank_plot,
)
from sbi.analysis.sensitivity_analysis import ActiveSubspace
Expand Down
240 changes: 239 additions & 1 deletion sbi/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,24 @@
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import collections
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
from warnings import warn

import matplotlib as mpl
import numpy as np
import six
import torch
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.colors import Normalize
from matplotlib.figure import Figure, FigureBase
from matplotlib.patches import Rectangle
from scipy.stats import binom, gaussian_kde, iqr
from torch import Tensor

from sbi.analysis import eval_conditional_density
from sbi.utils.analysis_utils import pp_vals

try:
collectionsAbc = collections.abc # type: ignore
Expand Down Expand Up @@ -1832,6 +1836,240 @@ def _plot_hist_region_expected_under_uniformity(
)


# Diagnostics for hypothesis tests


def pp_plot(
scores: Union[List[np.ndarray], Dict[Any, np.ndarray]],
scores_null: Union[List[np.ndarray], Dict[Any, np.ndarray]],
true_scores_null: np.ndarray,
conf_alpha: float,
n_alphas: int = 100,
labels: Optional[List[str]] = None,
colors: Optional[List[str]] = None,
ax: Optional[Axes] = None,
**kwargs: Any,
) -> Axes:
"""Probability - Probability (P-P) plot for hypothesis tests
to assess the validity of one (or several) estimator(s).
See [here](https://en.wikipedia.org/wiki/P%E2%80%93P_plot) for more details.
Args:
scores: test scores estimated on observed data and evaluated on the test set,
of shape (n_eval,). One array per estimator.
scores_null: test scores estimated under the null hypothesis and evaluated on
the test set, of shape (n_eval,). One array per null trial.
true_scores_null: theoretical true scores under the null hypothesis,
of shape (n_eval,).
labels: labels for the estimators, defaults to None.
colors: colors for the estimators, defaults to None.
conf_alpha: significanecee level of the hypothesis test.
n_alphas: number of cdf-values to compute the P-P plot, defaults to 100.
ax: axis to plot on, defaults to None.
kwargs: additional arguments for matplotlib plotting.
Returns:
ax: axes with the P-P plot.
"""
if ax is None:
ax = plt.gca()
ax_: Axes = cast(Axes, ax) # cast to fix pyright error

alphas = np.linspace(0, 1, n_alphas)

# pp_vals for the true null hypothesis
pp_vals_true = pp_vals(true_scores_null, alphas)
ax_.plot(alphas, pp_vals_true, "--", color="black", label="True Null (H0)")

# pp_vals for the estimated null hypothesis over the multiple trials
pp_vals_null = []
for t in range(len(scores_null)):
pp_vals_null.append(pp_vals(scores_null[t], alphas))
pp_vals_null = np.array(pp_vals_null)

# confidence region
quantiles = np.quantile(pp_vals_null, [conf_alpha / 2, 1 - conf_alpha / 2], axis=0)
ax_.fill_between(
alphas,
quantiles[0],
quantiles[1],
color="grey",
alpha=0.2,
label=f"{(1 - conf_alpha) * 100}% confidence region",
)

# pp_vals for the observed data
for i, p_ in enumerate(scores):
pp_vals_o = pp_vals(p_, alphas)
if labels is not None:
kwargs["label"] = labels[i]
if colors is not None:
kwargs["color"] = colors[i]
ax_.plot(alphas, pp_vals_o, **kwargs)
return ax_


def marginal_plot_with_probs_intensity(
probs_per_marginal: dict,
marginal_dim: int,
n_bins: int = 20,
vmin: float = 0.0,
vmax: float = 1.0,
cmap_name: str = "Spectral_r",
show_colorbar: bool = True,
label: Optional[str] = None,
ax: Optional[Axes] = None,
) -> Axes:
"""Plot 1d or 2d marginal histogram of samples of the density estimator
with probabilities as color intensity.
Args:
probs_per_marginal: dataframe with predicted class probabilities
as obtained from `sbi.utils.analysis_utils.get_probs_per_marginal`.
marginal_dim: dimension of the marginal histogram to plot.
n_bins: number of bins for the histogram, defaults to 20.
vmin: minimum value for the color intensity, defaults to 0.
vmax: maximum value for the color intensity, defaults to 1.
cmap: colormap for the color intensity, defaults to "Spectral_r".
show_colorbar: whether to show the colorbar, defaults to True.
label: label for the colorbar, defaults to None.
ax (matplotlib.axes.Axes): axes to plot on, defaults to None.
Returns:
ax (matplotlib.axes.Axes): axes with the plot.
"""
assert marginal_dim in [1, 2], "Only 1d or 2d marginals are supported."

if ax is None:
ax = plt.gca()
ax_: Axes = cast(Axes, ax) # cast to fix pyright error

if label is None:
label = "probability"

# get colormap
cmap = cm.get_cmap(cmap_name)

# case of 1d marginal
if marginal_dim == 1:
# extract bins and patches
_, bins, patches = ax_.hist(
probs_per_marginal['s'], n_bins, density=True, color="green"
)
# create bins: all samples between bin edges are assigned to the same bin
probs_per_marginal["bins"] = np.searchsorted(bins, probs_per_marginal['s']) - 1
probs_per_marginal["bins"][probs_per_marginal["bins"] < 0] = 0
# get mean prob for each bin (same as pandas groupy method)
array_probs = np.concatenate(
[probs_per_marginal['bins'][:, None], probs_per_marginal['probs'][:, None]],
axis=1,
)
array_probs = array_probs[array_probs[:, 0].argsort()]
weights = np.split(
array_probs[:, 1], np.unique(array_probs[:, 0], return_index=True)[1][1:]
)
weights = np.array([np.mean(w) for w in weights])
# remove empty bins
id = list(set(range(n_bins)) - set(probs_per_marginal['bins']))
patches = np.delete(patches, id)
bins = np.delete(bins, id)

# normalize color intensity
norm = Normalize(vmin=vmin, vmax=vmax)
# set color intensity
for w, p in zip(weights, patches):
p.set_facecolor(cmap(w))
if show_colorbar:
plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax_, label=label)

if marginal_dim == 2:
# extract bin edges
_, x, y = np.histogram2d(
probs_per_marginal['s_1'], probs_per_marginal['s_2'], bins=n_bins
)
# create bins: all samples between bin edges are assigned to the same bin
probs_per_marginal["bins_x"] = np.searchsorted(x, probs_per_marginal['s_1']) - 1
probs_per_marginal["bins_y"] = np.searchsorted(y, probs_per_marginal['s_2']) - 1
probs_per_marginal["bins_x"][probs_per_marginal["bins_x"] < 0] = 0
probs_per_marginal["bins_y"][probs_per_marginal["bins_y"] < 0] = 0

# extract unique bin pairs
group_idx = np.concatenate(
[
probs_per_marginal['bins_x'][:, None],
probs_per_marginal['bins_y'][:, None],
],
axis=1,
)
unique_bins = np.unique(group_idx, return_counts=True, axis=0)[0]

# get mean prob for each bin (same as pandas groupy method)
mean_probs = np.zeros((len(unique_bins),))
for i in range(len(unique_bins)):
idx = np.where((group_idx == unique_bins[i]).all(axis=1))
mean_probs[i] = np.mean(probs_per_marginal['probs'][idx])

# create weight matrix with nan values for non-existing bins
weights = np.zeros((n_bins, n_bins))
weights[:] = np.nan
weights[unique_bins[:, 0], unique_bins[:, 1]] = mean_probs

# set color intensity
norm = Normalize(vmin=vmin, vmax=vmax)
for i in range(len(x) - 1):
for j in range(len(y) - 1):
facecolor = cmap(norm(weights.T[j, i]))
# if no sample in bin, set color to white
if weights.T[j, i] == np.nan:
facecolor = "white"
rect = Rectangle(
(x[i], y[j]),
x[i + 1] - x[i],
y[j + 1] - y[j],
facecolor=facecolor,
edgecolor="none",
)
ax_.add_patch(rect)
if show_colorbar:
plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax_, label=label)

return ax_


# Customized plotting functions for LC2ST


def pp_plot_lc2st(
probs: Union[List[np.ndarray], Dict[Any, np.ndarray]],
probs_null: Union[List[np.ndarray], Dict[Any, np.ndarray]],
conf_alpha: float,
**kwargs: Any,
) -> Axes:
"""Probability - Probability (P-P) plot for LC2ST.
Args:
probs: predicted probability on observed data and evaluated on the test set,
of shape (n_eval,). One array per estimator.
probs_null: predicted probability under the null hypothesis and evaluated on
the test set, of shape (n_eval,). One array per null trial.
conf_alpha: significanecee level of the hypothesis test.
kwargs: additional arguments for `pp_plot`.
Returns:
ax: axes with the P-P plot.
"""
# probability at chance level (under the null) is 0.5
true_probs_null = np.array([0.5] * len(probs))
return pp_plot(
scores=probs,
scores_null=probs_null,
true_scores_null=true_probs_null,
conf_alpha=conf_alpha,
**kwargs,
)


# TO BE DEPRECATED
# ----------------
def pairplot_dep(
Expand Down

0 comments on commit 3c1e725

Please sign in to comment.