-
Notifications
You must be signed in to change notification settings - Fork 158
/
Copy pathpotential_test.py
66 lines (57 loc) · 2.33 KB
/
potential_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from __future__ import annotations
import pytest
import torch
from torch import eye, ones, zeros
from torch.distributions import MultivariateNormal
from sbi.inference import (
ImportanceSamplingPosterior,
MCMCPosterior,
RejectionPosterior,
VIPosterior,
)
@pytest.mark.parametrize(
    "sampling_method",
    [
        ImportanceSamplingPosterior,
        pytest.param(MCMCPosterior, marks=pytest.mark.mcmc),
        RejectionPosterior,
        VIPosterior,
    ],
)
def test_callable_potential(sampling_method, mcmc_params_accurate: dict):
    """Check that a plain callable can serve as the potential for each sampler.

    A Gaussian log-density, evaluated at ``theta + x_o``, is handed to the
    posterior class under test as ``potential_fn``. Samples are then drawn and
    their empirical mean and standard deviation are compared against the
    values implied by that shifted Gaussian.

    Args:
        sampling_method: Posterior class to build around the potential; one of
            the classes listed in the parametrization above.
        mcmc_params_accurate: Extra keyword arguments forwarded to
            ``MCMCPosterior.sample``. Presumably a pytest fixture defined
            outside this file -- confirm in the test configuration.
    """
    num_dims = 2
    target_mean = 2.5
    target_var = 2.0
    sample_shape = (1024,)

    observation = 1 * ones((num_dims,))
    target = MultivariateNormal(
        target_mean * ones((num_dims,)), target_var * eye(num_dims)
    )

    # The parameter names `theta` and `x_o` are kept as-is: the posterior
    # classes receive this function and may rely on them.
    def potential(theta, x_o):
        return target.log_prob(theta + x_o)

    proposal = MultivariateNormal(zeros((num_dims,)), 5 * eye(num_dims))

    # Each posterior class has its own constructor and sampling signature, so
    # every case is handled explicitly.
    if sampling_method == RejectionPosterior:
        posterior = sampling_method(potential_fn=potential, proposal=proposal)
        approx_density = posterior.set_default_x(observation)
        approx_samples = approx_density.sample(sample_shape)
    elif sampling_method == VIPosterior:
        # The VI posterior takes the distribution as `prior` and has to be
        # trained before sampling.
        posterior = sampling_method(potential_fn=potential, prior=proposal)
        approx_density = posterior.set_default_x(observation)
        approx_density = approx_density.train()
        approx_samples = approx_density.sample(sample_shape)
    elif sampling_method == MCMCPosterior:
        approx_density = sampling_method(potential_fn=potential, proposal=proposal)
        approx_samples = approx_density.sample(
            sample_shape,
            x=observation,
            method="slice_np_vectorized",
            **mcmc_params_accurate,
        )
    elif sampling_method == ImportanceSamplingPosterior:
        approx_density = sampling_method(
            potential_fn=potential, proposal=proposal, method="sir"
        )
        approx_samples = approx_density.sample(
            sample_shape, oversampling_factor=1024, x=observation
        )

    # Because the potential evaluates the target at `theta + x_o`, theta itself
    # is centred at `target_mean - x_o`, with the target's own spread.
    expected_mean = torch.as_tensor(target_mean) - observation
    expected_std = torch.sqrt(torch.as_tensor(target_var))

    sample_mean = torch.mean(approx_samples, dim=0)
    sample_std = torch.std(approx_samples, dim=0)

    assert torch.allclose(sample_mean, expected_mean, atol=0.2)
    assert torch.allclose(sample_std, expected_std, atol=0.1)