
Commit

Improvements to posterior routine and high-dim example.
cweniger committed Feb 14, 2021
1 parent 7a576d4 commit fbdd291
Showing 4 changed files with 867 additions and 16 deletions.
285 changes: 285 additions & 0 deletions notebooks/Development - High-dim Quality.ipynb

Large diffs are not rendered by default.

555 changes: 555 additions & 0 deletions notebooks/Development - Ratio plots.ipynb

Large diffs are not rendered by default.

17 changes: 11 additions & 6 deletions swyft/plot.py
@@ -3,7 +3,7 @@
from scipy.interpolate import griddata

from .types import Array, Sequence, Tuple
-from .utils import verbosity
+from .utils import verbosity, grid_interpolate_samples

# def get_contour_levels(x, cred_level=[0.68268, 0.95450, 0.99730]):
# x = np.sort(x)[::-1] # Sort backwards
@@ -65,6 +65,7 @@ def plot1d(
    ncol=None,
    truth=None,
    bins=100,
+    grid_interpolate=False,
    label_args={},
) -> None:

@@ -90,7 +91,7 @@
        else:
            i, j = k % ncol, k // ncol
        ax = axes[j, i]
-        plot_posterior(post, params[k], ax=ax, color=color, bins=bins)
+        plot_posterior(post, params[k], ax=ax, grid_interpolate=grid_interpolate, color=color, bins=bins)
        ax.set_xlabel(labels[k], **label_args)
        if truth is not None:
            ax.axvline(truth[params[k]], ls=":", color="r")
@@ -185,7 +186,7 @@ def contour1d(z, v, levels, ax=plt, linestyles=None, color=None, **kwargs):


def plot_posterior(
-    post, params, weights_key=None, ax=plt, bins=100, color="k", **kwargs
+    post, params, weights_key=None, ax=plt, grid_interpolate = False, bins=100, color="k", **kwargs
):
    if isinstance(params, str):
        params = (params,)
@@ -199,9 +200,13 @@ def plot_posterior(

    if len(params) == 1:
        x = post["params"][params[0]]
        # v, e, _ = ax.hist(x, weights = w, bins = bins, color = color, alpha = 0.2)
-        v, e = np.histogram(x, weights=w, bins=bins, density=True)
-        zm = (e[1:] + e[:-1]) / 2
+
+        if grid_interpolate:
+            zm, v = grid_interpolate_samples(x, w)
+        else:
+            v, e = np.histogram(x, weights=w, bins=bins, density=True)
+            zm = (e[1:] + e[:-1]) / 2

        levels = sorted(get_contour_levels(v))
        contour1d(zm, v, levels, ax=ax, color=color)
        ax.plot(zm, v, color=color, **kwargs)
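For context, a minimal NumPy-only sketch (not the swyft API; the variable names are illustrative) of the two 1-d density estimates that plot_posterior now switches between via grid_interpolate:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=5000)          # stand-in for post["params"][...]
w = np.exp(-0.5 * (x - 0.5) ** 2)  # stand-in for the sample weights

# grid_interpolate=False: weighted, density-normalized histogram.
v, e = np.histogram(x, weights=w, bins=100, density=True)
zm = (e[1:] + e[:-1]) / 2          # bin centers

# grid_interpolate=True: sort the samples, interpolate the weights onto a
# regular grid, and normalize by the integral; swyft's helper uses scipy's
# simps, while np.trapz keeps this sketch dependency-free.
idx = np.argsort(x)
zg = np.linspace(x[idx][0], x[idx][-1], 1000)
yg = np.interp(zg, x[idx], w[idx])
vg = yg / np.trapz(yg, zg)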
26 changes: 16 additions & 10 deletions swyft/utils.py
@@ -375,6 +375,17 @@ def swyftify_observation(observation: torch.Tensor):
def unswyftify_observation(swyft_observation: dict):
    return swyft_observation["x"]

+def grid_interpolate_samples(x, y, bins = 1000, return_norm = False):
+    idx = np.argsort(x)
+    x, y = x[idx], y[idx]
+    x_grid = np.linspace(x[0], x[-1], bins)
+    y_grid = np.interp(x_grid, x, y)
+    norm = simps(y_grid, x_grid)
+    y_grid_normed = y_grid/norm
+    if return_norm:
+        return x_grid, y_grid_normed, norm
+    else:
+        return x_grid, y_grid_normed
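# Minimal usage sketch for grid_interpolate_samples (illustrative names;
# assumes numpy as np and scipy.integrate's simps in scope, as in this
# module): y holds unnormalized density values at sample positions x.
_x = np.random.default_rng(1).uniform(-4, 4, size=200)
_y = np.exp(-0.5 * _x**2)  # unnormalized unit Gaussian
_xg, _yn, _norm = grid_interpolate_samples(_x, _y, bins=1000, return_norm=True)
# simps(_yn, _xg) is ~1.0 and _norm is ~sqrt(2*pi), up to interpolation error.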

def get_entropy(x, y, y_true = None, bins = 1000):
"""Estimate 1-dim entropy, norm and KL divergence.
@@ -385,20 +396,15 @@ def get_entropy(x, y, y_true = None, bins = 1000):
        y_true (function): functional form of the true probability density for KL calculation
        bins (int): Number of bins to use for interpolation.
    """
-    idx = np.argsort(x)
-    x, y = x[idx], y[idx]
-    x_grid = np.linspace(x[0], x[-1], bins)
-    y_grid = np.interp(x_grid, x, y)
-    norm = simps(y_grid, x_grid)
-    y_grid_normed = y_grid/norm
-    entropy = simps(y_grid_normed*np.log(y_grid_normed), x_grid)
+    x_int, y_int, norm = grid_interpolate_samples(x, y, bins = bins, return_norm = True)
+    entropy = simps(y_int*np.log(y_int), x_int)
    if y_true is not None:
-        y_grid_true = y_true(x_grid)
-        KL = simps(y_grid_normed*np.log(y_grid_normed/y_grid_true), x_grid)
+        y_int_true = y_true(x_int)
+        KL = simps(y_int*np.log(y_int/y_int_true), x_int)
        return dict(entropy = entropy, norm = norm, KL = KL)
    return dict(entropy = entropy, norm = norm)
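# Quick sketch of get_entropy on unit-Gaussian samples, checked against the
# analytic density (illustrative; assumes scipy.stats is available). Sign
# convention: 'entropy' is the integral of p*log(p), i.e. minus the usual
# differential entropy, so it lands near -1.42 here; norm is ~1 and KL is ~0
# because y matches y_true.
from scipy.stats import norm as _gauss
_xs = np.random.default_rng(2).normal(size=10_000)
_diag = get_entropy(_xs, _gauss.pdf(_xs), y_true=_gauss.pdf)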

-def sample_diagnostics(samples, true_posteriors = {}):
+def sample_diagnostics(samples, true_posteriors = {}, true_params = {}):
    result = {}
    for params in samples['weights'].keys():
        if len(params) > 1:
