Skip to content

Commit

Permalink
Omit nans in plotting functions, cast to numpy before plotting, fix a…
Browse files Browse the repository at this point in the history
…xes indexing for 1D (#1185)

* cast to numpy

* omit nans, dont index axes if 1D

* black

* revert back

* rename

* ruffed

* ignore matplotlib pyright

* more ignore matplotlib pyright

* ruffed

* remove duplicate tutorial notebook

* fix typos in notebook

* add plotting to docs

* added tests for plotting

---------

Co-authored-by: Matthijs <matthijs@example.com>
  • Loading branch information
Matthijspals and Matthijs committed Jun 26, 2024
1 parent 6f61662 commit 1db7aaf
Show file tree
Hide file tree
Showing 5 changed files with 638 additions and 503 deletions.
1 change: 1 addition & 0 deletions docs/mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ nav:
- Analysis:
- Conditional distributions: tutorial/07_conditional_distributions.md
- Posterior sensitivity analysis: tutorial/09_sensitivity_analysis.md
- Pair and marginal plots: tutorial/19_plotting_functionality.md
- Examples:
- Hodgkin-Huxley example: examples/00_HH_simulator.md
- Decision making model: examples/01_decision_making_model.md
Expand Down
128 changes: 76 additions & 52 deletions sbi/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import collections
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
from warnings import warn

Expand Down Expand Up @@ -81,23 +82,23 @@ def plt_hist_1d(
) -> None:
"""Plot 1D histogram."""
if (
"bins" not in diag_kwargs['mpl_kwargs']
or diag_kwargs['mpl_kwargs']["bins"] is None
"bins" not in diag_kwargs["mpl_kwargs"]
or diag_kwargs["mpl_kwargs"]["bins"] is None
):
if diag_kwargs["bin_heuristic"] == "Freedman-Diaconis":
# The Freedman-Diaconis heuristic
binsize = 2 * iqr(samples) * len(samples) ** (-1 / 3)
diag_kwargs['mpl_kwargs']["bins"] = np.arange(
diag_kwargs["mpl_kwargs"]["bins"] = np.arange(
limits[0], limits[1] + binsize, binsize
)
else:
# TODO: add more bin heuristics
pass
if isinstance(diag_kwargs['mpl_kwargs']["bins"], int):
diag_kwargs['mpl_kwargs']["bins"] = np.linspace(
limits[0], limits[1], diag_kwargs['mpl_kwargs']["bins"]
if isinstance(diag_kwargs["mpl_kwargs"]["bins"], int):
diag_kwargs["mpl_kwargs"]["bins"] = np.linspace(
limits[0], limits[1], diag_kwargs["mpl_kwargs"]["bins"]
)
ax.hist(samples, **diag_kwargs['mpl_kwargs'])
ax.hist(samples, **diag_kwargs["mpl_kwargs"])


def plt_kde_1d(
Expand All @@ -110,7 +111,7 @@ def plt_kde_1d(
density = gaussian_kde(samples, bw_method=diag_kwargs["bw_method"])
xs = np.linspace(limits[0], limits[1], diag_kwargs["bins"])
ys = density(xs)
ax.plot(xs, ys, **diag_kwargs['mpl_kwargs'])
ax.plot(xs, ys, **diag_kwargs["mpl_kwargs"])


def plt_scatter_1d(
Expand All @@ -121,7 +122,7 @@ def plt_scatter_1d(
) -> None:
"""Plot vertical lines for each sample. Note: limits are not used."""
for single_sample in samples:
ax.axvline(single_sample, **diag_kwargs['mpl_kwargs'])
ax.axvline(single_sample, **diag_kwargs["mpl_kwargs"])


def plt_hist_2d(
Expand All @@ -134,16 +135,16 @@ def plt_hist_2d(
):
"""Plot 2D histogram."""
if (
"bins" not in offdiag_kwargs['np_hist_kwargs']
or offdiag_kwargs['np_hist_kwargs']["bins"] is None
"bins" not in offdiag_kwargs["np_hist_kwargs"]
or offdiag_kwargs["np_hist_kwargs"]["bins"] is None
):
if offdiag_kwargs["bin_heuristic"] == "Freedman-Diaconis":
# The Freedman-Diaconis heuristic applied to each direction
binsize_col = 2 * iqr(samples_col) * len(samples_col) ** (-1 / 3)
n_bins_col = int((limits_col[1] - limits_col[0]) / binsize_col)
binsize_row = 2 * iqr(samples_row) * len(samples_row) ** (-1 / 3)
n_bins_row = int((limits_row[1] - limits_row[0]) / binsize_row)
offdiag_kwargs['np_hist_kwargs']["bins"] = [n_bins_col, n_bins_row]
offdiag_kwargs["np_hist_kwargs"]["bins"] = [n_bins_col, n_bins_row]
else:
# TODO: add more bin heuristics
pass
Expand All @@ -154,7 +155,7 @@ def plt_hist_2d(
[limits_col[0], limits_col[1]],
[limits_row[0], limits_row[1]],
],
**offdiag_kwargs['np_hist_kwargs'],
**offdiag_kwargs["np_hist_kwargs"],
)
ax.imshow(
hist.T,
Expand All @@ -164,7 +165,7 @@ def plt_hist_2d(
yedges[0],
yedges[-1],
),
**offdiag_kwargs['mpl_kwargs'],
**offdiag_kwargs["mpl_kwargs"],
)


Expand All @@ -187,7 +188,7 @@ def plt_kde_2d(
limits_row[0],
limits_row[1],
),
**offdiag_kwargs['mpl_kwargs'],
**offdiag_kwargs["mpl_kwargs"],
)


Expand All @@ -214,7 +215,7 @@ def plt_contour_2d(
limits_row[1],
),
levels=offdiag_kwargs["levels"],
**offdiag_kwargs['mpl_kwargs'],
**offdiag_kwargs["mpl_kwargs"],
)


Expand All @@ -230,7 +231,7 @@ def plt_scatter_2d(
ax.scatter(
samples_col,
samples_row,
**offdiag_kwargs['mpl_kwargs'],
**offdiag_kwargs["mpl_kwargs"],
)


Expand All @@ -247,7 +248,7 @@ def plt_plot_2d(
ax.plot(
samples_col,
samples_row,
**offdiag_kwargs['mpl_kwargs'],
**offdiag_kwargs["mpl_kwargs"],
)


Expand Down Expand Up @@ -304,11 +305,11 @@ def get_diag_funcs(
"""make a list of the functions for the diagonal plots."""
diag_funcs = []
for diag in diag_list:
if diag == 'hist':
if diag == "hist":
diag_funcs.append(plt_hist_1d)
elif diag == 'kde':
elif diag == "kde":
diag_funcs.append(plt_kde_1d)
elif diag == 'scatter':
elif diag == "scatter":
diag_funcs.append(plt_scatter_1d)
else:
diag_funcs.append(None)
Expand All @@ -335,15 +336,15 @@ def get_offdiag_funcs(
"""make a list of the functions for the off-diagonal plots."""
offdiag_funcs = []
for offdiag in off_diag_list:
if offdiag == 'hist' or offdiag == 'hist2d':
if offdiag == "hist" or offdiag == "hist2d":
offdiag_funcs.append(plt_hist_2d)
elif offdiag == 'kde' or offdiag == 'kde2d':
elif offdiag == "kde" or offdiag == "kde2d":
offdiag_funcs.append(plt_kde_2d)
elif offdiag == 'contour' or offdiag == 'contourf':
elif offdiag == "contour" or offdiag == "contourf":
offdiag_funcs.append(plt_contour_2d)
elif offdiag == 'scatter':
elif offdiag == "scatter":
offdiag_funcs.append(plt_scatter_2d)
elif offdiag == 'plot':
elif offdiag == "plot":
offdiag_funcs.append(plt_plot_2d)
else:
offdiag_funcs.append(None)
Expand Down Expand Up @@ -534,8 +535,24 @@ def ensure_numpy(t: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
"""
if isinstance(t, torch.Tensor):
return t.numpy()
else:
return t
elif not isinstance(t, np.ndarray):
return np.array(t)
return t


def handle_nan_infs(samples: List[np.ndarray]) -> List[np.ndarray]:
"""Check if there are NaNs or Infs in the samples."""
for i in range(len(samples)):
if np.isnan(samples[i]).any():
logging.warning("NaNs found in samples, omitting datapoints.")
if np.isinf(samples[i]).any():
logging.warning("Infs found in samples, omitting datapoints.")
# cast inf to nan, so they are omitted in the next step
np.nan_to_num(
samples[i], copy=False, nan=np.nan, posinf=np.nan, neginf=np.nan
)
samples[i] = samples[i][~np.isnan(samples[i]).any(axis=1)]
return samples


def prepare_for_plot(
Expand All @@ -554,6 +571,9 @@ def prepare_for_plot(
else:
samples = [ensure_numpy(sample) for sample in samples]

# check if nans and infs
samples = handle_nan_infs(samples)

# Dimensionality of the problem.
dim = samples[0].shape[1]

Expand All @@ -564,9 +584,9 @@ def prepare_for_plot(
min = +np.inf
max = -np.inf
for sample in samples:
min_ = sample[:, d].min()
min_ = np.min(sample[:, d])
min = min_ if min_ < min else min
max_ = sample[:, d].max()
max_ = np.max(sample[:, d])
max = max_ if max_ > max else max
limits.append([min, max])
else:
Expand Down Expand Up @@ -1277,7 +1297,7 @@ def _arrange_grid(
excl_upper = all(v is None for v in upper_funcs)
excl_diag = all(v is None for v in diag_funcs)
flat = excl_lower and excl_upper

one_dim = dim == 1
# select the subset of rows and cols to be plotted
if flat:
rows = 1
Expand Down Expand Up @@ -1309,11 +1329,15 @@ def _arrange_grid(
else:
current = "lower"

ax = axes[col_idx] if flat else axes[row_idx, col_idx] # pyright: ignore reportIndexIssue

if one_dim:
ax = axes # pyright: ignore reportIndexIssue
elif flat:
ax = axes[col_idx] # pyright: ignore reportIndexIssue
else:
ax = axes[row_idx, col_idx] # pyright: ignore reportIndexIssue
# Diagonals
_format_subplot(
ax,
ax, # pyright: ignore reportArgumentType
current,
limits,
ticks,
Expand All @@ -1327,7 +1351,7 @@ def _arrange_grid(
)
if current == "diag":
if excl_diag:
ax.axis("off")
ax.axis("off") # pyright: ignore reportOptionalMemberAccess
else:
for sample_ind, sample in enumerate(samples):
diag_f = diag_funcs[sample_ind]
Expand All @@ -1337,24 +1361,24 @@ def _arrange_grid(
)

if len(points) > 0:
extent = ax.get_ylim()
extent = ax.get_ylim() # pyright: ignore reportOptionalMemberAccess
for n, v in enumerate(points):
ax.plot(
ax.plot( # pyright: ignore reportOptionalMemberAccess
[v[:, col], v[:, col]],
extent,
color=fig_kwargs["points_colors"][n],
**fig_kwargs["points_diag"],
label=fig_kwargs["points_labels"][n],
)
if fig_kwargs["legend"] and col == 0:
ax.legend(**fig_kwargs["legend_kwargs"])
ax.legend(**fig_kwargs["legend_kwargs"]) # pyright: ignore reportOptionalMemberAccess

# Off-diagonals

# upper
elif current == "upper":
if excl_upper:
ax.axis("off")
ax.axis("off") # pyright: ignore reportOptionalMemberAccess
else:
for sample_ind, sample in enumerate(samples):
upper_f = upper_funcs[sample_ind]
Expand All @@ -1369,7 +1393,7 @@ def _arrange_grid(
)
if len(points) > 0:
for n, v in enumerate(points):
ax.plot(
ax.plot( # pyright: ignore reportOptionalMemberAccess
v[:, col],
v[:, row],
color=fig_kwargs["points_colors"][n],
Expand All @@ -1378,7 +1402,7 @@ def _arrange_grid(
# lower
elif current == "lower":
if excl_lower:
ax.axis("off")
ax.axis("off") # pyright: ignore reportOptionalMemberAccess
else:
for sample_ind, sample in enumerate(samples):
lower_f = lower_funcs[sample_ind]
Expand All @@ -1393,7 +1417,7 @@ def _arrange_grid(
)
if len(points) > 0:
for n, v in enumerate(points):
ax.plot(
ax.plot( # pyright: ignore reportOptionalMemberAccess
v[:, col],
v[:, row],
color=fig_kwargs["points_colors"][n],
Expand Down Expand Up @@ -1955,14 +1979,14 @@ def marginal_plot_with_probs_intensity(
if marginal_dim == 1:
# extract bins and patches
_, bins, patches = ax_.hist(
probs_per_marginal['s'], n_bins, density=True, color="green"
probs_per_marginal["s"], n_bins, density=True, color="green"
)
# create bins: all samples between bin edges are assigned to the same bin
probs_per_marginal["bins"] = np.searchsorted(bins, probs_per_marginal['s']) - 1
probs_per_marginal["bins"] = np.searchsorted(bins, probs_per_marginal["s"]) - 1
probs_per_marginal["bins"][probs_per_marginal["bins"] < 0] = 0
# get mean prob for each bin (same as pandas groupy method)
array_probs = np.concatenate(
[probs_per_marginal['bins'][:, None], probs_per_marginal['probs'][:, None]],
[probs_per_marginal["bins"][:, None], probs_per_marginal["probs"][:, None]],
axis=1,
)
array_probs = array_probs[array_probs[:, 0].argsort()]
Expand All @@ -1971,7 +1995,7 @@ def marginal_plot_with_probs_intensity(
)
weights = np.array([np.mean(w) for w in weights])
# remove empty bins
id = list(set(range(n_bins)) - set(probs_per_marginal['bins']))
id = list(set(range(n_bins)) - set(probs_per_marginal["bins"]))
patches = np.delete(patches, id)
bins = np.delete(bins, id)

Expand All @@ -1986,19 +2010,19 @@ def marginal_plot_with_probs_intensity(
if marginal_dim == 2:
# extract bin edges
_, x, y = np.histogram2d(
probs_per_marginal['s_1'], probs_per_marginal['s_2'], bins=n_bins
probs_per_marginal["s_1"], probs_per_marginal["s_2"], bins=n_bins
)
# create bins: all samples between bin edges are assigned to the same bin
probs_per_marginal["bins_x"] = np.searchsorted(x, probs_per_marginal['s_1']) - 1
probs_per_marginal["bins_y"] = np.searchsorted(y, probs_per_marginal['s_2']) - 1
probs_per_marginal["bins_x"] = np.searchsorted(x, probs_per_marginal["s_1"]) - 1
probs_per_marginal["bins_y"] = np.searchsorted(y, probs_per_marginal["s_2"]) - 1
probs_per_marginal["bins_x"][probs_per_marginal["bins_x"] < 0] = 0
probs_per_marginal["bins_y"][probs_per_marginal["bins_y"] < 0] = 0

# extract unique bin pairs
group_idx = np.concatenate(
[
probs_per_marginal['bins_x'][:, None],
probs_per_marginal['bins_y'][:, None],
probs_per_marginal["bins_x"][:, None],
probs_per_marginal["bins_y"][:, None],
],
axis=1,
)
Expand All @@ -2008,7 +2032,7 @@ def marginal_plot_with_probs_intensity(
mean_probs = np.zeros((len(unique_bins),))
for i in range(len(unique_bins)):
idx = np.where((group_idx == unique_bins[i]).all(axis=1))
mean_probs[i] = np.mean(probs_per_marginal['probs'][idx])
mean_probs[i] = np.mean(probs_per_marginal["probs"][idx])

# create weight matrix with nan values for non-existing bins
weights = np.zeros((n_bins, n_bins))
Expand Down
Loading

0 comments on commit 1db7aaf

Please sign in to comment.