Skip to content

Commit

Permalink
Test for valid use of (S)NPE API (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
psteinb committed Nov 12, 2021
1 parent 67394a9 commit a9c20c8
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
48 changes: 48 additions & 0 deletions tests/algorithms/sbi/test_snpe_posterior.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest
import torch

import sbibm
from sbibm.algorithms.sbi.snpe import run as run_posterior
from sbibm.metrics.c2st import c2st


# a fast test
@pytest.mark.parametrize(
"task_name,num_observation",
[
(task_name, num_observation)
for task_name in [
"gaussian_linear",
"gaussian_linear_uniform",
]
for num_observation in [1, 3]
],
)
def test_npe_posterior(
task_name, num_observation, num_simulations=2_000, num_samples=100
):
task = sbibm.get_task(task_name)

predicted, _, _ = run_posterior(
task=task,
num_observation=num_observation,
num_simulations=num_simulations,
num_samples=num_samples,
num_rounds=1,
neural_net="mdn",
max_num_epochs=30,
)

reference_samples = task.get_reference_posterior_samples(
num_observation=num_observation
)

expected = reference_samples[:num_samples, :]

assert expected.shape == predicted.shape

acc = c2st(predicted, expected)

assert acc > 0.5
assert acc < 1.0
assert acc > 0.6
2 changes: 1 addition & 1 deletion tests/algorithms/test_baseline_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"gaussian_linear_uniform",
"slcp",
]
for num_observation in range(1, 11)
for num_observation in [1, 2]
],
)
def test_posterior(
Expand Down

0 comments on commit a9c20c8

Please sign in to comment.